Skip to content

Commit 8fa8fa7

Browse files
vertex-sdk-bot and copybara-github
authored and committed
feat: GenAI Client(evals) - add user-facing generate_loss_clusters with LRO polling and replay tests
PiperOrigin-RevId: 893874547
1 parent 09794ba commit 8fa8fa7

File tree

5 files changed

+387
-1
lines changed

5 files changed

+387
-1
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai import types
19+
import pytest
20+
21+
22+
def test_gen_loss_clusters(client):
    """Verifies generate_loss_clusters() yields a populated GenerateLossClustersResponse."""
    analysis_config = types.LossAnalysisConfig(
        metric="multi_turn_task_success_v1",
        candidate="travel-agent",
    )
    response = client.evals.generate_loss_clusters(
        eval_result=types.EvaluationResult(),
        config=analysis_config,
    )
    assert isinstance(response, types.GenerateLossClustersResponse)
    assert len(response.results) == 1

    only_result = response.results[0]
    assert only_result.config.metric == "multi_turn_task_success_v1"
    assert only_result.config.candidate == "travel-agent"
    assert len(only_result.clusters) == 2

    first, second = only_result.clusters[0], only_result.clusters[1]
    assert first.cluster_id == "cluster-1"
    assert first.taxonomy_entry.l1_category == "Tool Calling"
    assert first.taxonomy_entry.l2_category == "Missing Tool Invocation"
    assert first.item_count == 3
    assert second.cluster_id == "cluster-2"
    assert second.taxonomy_entry.l1_category == "Hallucination"
    assert second.item_count == 2
47+
48+
49+
# Enables the pytest-asyncio plugin so the @pytest.mark.asyncio test in this
# module is collected and executed with an event loop.
pytest_plugins = ("pytest_asyncio",)
50+
51+
52+
@pytest.mark.asyncio
async def test_gen_loss_clusters_async(client):
    """Verifies the async generate_loss_clusters() yields a GenerateLossClustersResponse."""
    analysis_config = types.LossAnalysisConfig(
        metric="multi_turn_task_success_v1",
        candidate="travel-agent",
    )
    response = await client.aio.evals.generate_loss_clusters(
        eval_result=types.EvaluationResult(),
        config=analysis_config,
    )
    assert isinstance(response, types.GenerateLossClustersResponse)
    assert len(response.results) == 1

    sole_result = response.results[0]
    assert sole_result.config.metric == "multi_turn_task_success_v1"
    assert len(sole_result.clusters) == 2
    cluster_ids = [cluster.cluster_id for cluster in sole_result.clusters]
    assert cluster_ids == ["cluster-1", "cluster-2"]
70+
71+
72+
# Wires this module into the replay test harness: the harness records/replays
# API interactions for the named test_method so these tests run hermetically.
pytestmark = pytest_helper.setup(
    file=__file__,
    globals_for_file=globals(),
    test_method="evals.generate_loss_clusters",
)

tests/unit/vertexai/genai/test_evals.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from google.cloud.aiplatform import initializer as aiplatform_initializer
3030
from vertexai import _genai
3131
from vertexai._genai import _evals_data_converters
32+
from vertexai._genai import _evals_utils
3233
from vertexai._genai import _evals_metric_handlers
3334
from vertexai._genai import _evals_visualization
3435
from vertexai._genai import _evals_metric_loaders
@@ -265,6 +266,100 @@ def test_t_inline_results(self):
265266
assert payload[0]["candidate_results"][0]["score"] == 0.0
266267

267268

269+
class TestLossAnalysis:
    """Unit tests for loss analysis types and visualization."""

    def test_response_structure(self):
        """Round-trips a fully-populated response through the type wrappers."""
        tool_cluster = common_types.LossCluster(
            cluster_id="cluster-1",
            taxonomy_entry=common_types.LossTaxonomyEntry(
                l1_category="Tool Calling",
                l2_category="Missing Tool Invocation",
                description="The agent failed to invoke a required tool.",
            ),
            item_count=3,
        )
        hallucination_cluster = common_types.LossCluster(
            cluster_id="cluster-2",
            taxonomy_entry=common_types.LossTaxonomyEntry(
                l1_category="Hallucination",
                l2_category="Hallucination of Action",
                description="Verbally confirmed action without tool.",
            ),
            item_count=2,
        )
        response = common_types.GenerateLossClustersResponse(
            analysis_time="2026-04-01T10:00:00Z",
            results=[
                common_types.LossAnalysisResult(
                    config=common_types.LossAnalysisConfig(
                        metric="multi_turn_task_success_v1",
                        candidate="travel-agent",
                    ),
                    analysis_time="2026-04-01T10:00:00Z",
                    clusters=[tool_cluster, hallucination_cluster],
                )
            ],
        )
        assert response.analysis_time == "2026-04-01T10:00:00Z"
        assert len(response.results) == 1
        parsed = response.results[0]
        assert parsed.config.metric == "multi_turn_task_success_v1"
        assert len(parsed.clusters) == 2
        assert parsed.clusters[0].cluster_id == "cluster-1"
        assert parsed.clusters[0].item_count == 3
        assert parsed.clusters[1].cluster_id == "cluster-2"

    def test_response_show_with_results(self, capsys):
        """show() on a response prints the metric and cluster identifiers."""
        lone_cluster = common_types.LossCluster(
            cluster_id="c1",
            taxonomy_entry=common_types.LossTaxonomyEntry(
                l1_category="Cat1",
                l2_category="SubCat1",
            ),
            item_count=5,
        )
        response = common_types.GenerateLossClustersResponse(
            results=[
                common_types.LossAnalysisResult(
                    config=common_types.LossAnalysisConfig(
                        metric="test_metric",
                        candidate="test-candidate",
                    ),
                    clusters=[lone_cluster],
                )
            ],
        )
        response.show()
        printed = capsys.readouterr().out
        assert "test_metric" in printed
        assert "c1" in printed

    def test_loss_analysis_result_show(self, capsys):
        """show() on a single result also prints the metric and cluster id."""
        direct_result = common_types.LossAnalysisResult(
            config=common_types.LossAnalysisConfig(
                metric="test_metric",
                candidate="test-candidate",
            ),
            clusters=[
                common_types.LossCluster(
                    cluster_id="c1",
                    taxonomy_entry=common_types.LossTaxonomyEntry(
                        l1_category="DirectCat",
                        l2_category="DirectSubCat",
                    ),
                    item_count=7,
                )
            ],
        )
        direct_result.show()
        printed = capsys.readouterr().out
        assert "test_metric" in printed
        assert "c1" in printed
361+
362+
268363
class TestEvals:
269364
"""Unit tests for the GenAI client."""
270365

vertexai/_genai/_evals_utils.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
"""Utility functions for evals."""
1616

1717
import abc
18+
import asyncio
19+
import json
1820
import logging
1921
import os
20-
import json
22+
import time
2123
from typing import Any, Optional, Union
2224

2325
from google.genai._api_client import BaseApiClient
@@ -366,6 +368,112 @@ def _postprocess_user_scenarios_response(
366368
)
367369

368370

371+
def _display_loss_analysis_result(
372+
result: types.LossAnalysisResult,
373+
) -> None:
374+
"""Displays a LossAnalysisResult as a formatted pandas DataFrame."""
375+
metric = result.config.metric if result.config else None
376+
candidate = result.config.candidate if result.config else None
377+
rows = []
378+
for cluster in result.clusters or []:
379+
entry = cluster.taxonomy_entry
380+
row = {
381+
"metric": metric,
382+
"candidate": candidate,
383+
"cluster_id": cluster.cluster_id,
384+
"l1_category": entry.l1_category if entry else None,
385+
"l2_category": entry.l2_category if entry else None,
386+
"description": entry.description if entry else None,
387+
"item_count": cluster.item_count,
388+
}
389+
rows.append(row)
390+
391+
if not rows:
392+
print("No loss clusters found.") # pylint: disable=print-function
393+
return
394+
395+
df = pd.DataFrame(rows)
396+
try:
397+
from IPython.display import display # pylint: disable=g-import-not-at-top
398+
399+
display(df)
400+
except ImportError:
401+
print(df.to_string()) # pylint: disable=print-function
402+
403+
404+
405+
406+
407+
def _poll_operation(
408+
api_client: BaseApiClient,
409+
operation: types.GenerateLossClustersOperation,
410+
poll_interval_seconds: float = 5.0,
411+
) -> types.GenerateLossClustersOperation:
412+
"""Polls a long-running operation until completion.
413+
414+
Args:
415+
api_client: The API client to use for polling.
416+
operation: The initial operation returned from the API call.
417+
poll_interval_seconds: Time between polls.
418+
419+
Returns:
420+
The completed operation.
421+
"""
422+
if operation.done:
423+
return operation
424+
start_time = time.time()
425+
while True:
426+
response = api_client.request("get", operation.name, {}, None)
427+
response_dict = {} if not response.body else json.loads(response.body)
428+
polled = types.GenerateLossClustersOperation._from_response(
429+
response=response_dict, kwargs={}
430+
)
431+
if polled.done:
432+
return polled
433+
elapsed = int(time.time() - start_time)
434+
logger.info(
435+
"Loss analysis operation still running... Elapsed time: %d seconds",
436+
elapsed,
437+
)
438+
time.sleep(poll_interval_seconds)
439+
440+
441+
async def _poll_operation_async(
442+
api_client: BaseApiClient,
443+
operation: types.GenerateLossClustersOperation,
444+
poll_interval_seconds: float = 5.0,
445+
) -> types.GenerateLossClustersOperation:
446+
"""Polls a long-running operation until completion (async).
447+
448+
Args:
449+
api_client: The API client to use for polling.
450+
operation: The initial operation returned from the API call.
451+
poll_interval_seconds: Time between polls.
452+
453+
Returns:
454+
The completed operation.
455+
"""
456+
if operation.done:
457+
return operation
458+
start_time = time.time()
459+
while True:
460+
response = await api_client.async_request(
461+
"get", operation.name, {}, None
462+
)
463+
response_dict = {} if not response.body else json.loads(response.body)
464+
polled = types.GenerateLossClustersOperation._from_response(
465+
response=response_dict, kwargs={}
466+
)
467+
if polled.done:
468+
return polled
469+
elapsed = int(time.time() - start_time)
470+
logger.info(
471+
"Loss analysis operation still running... Elapsed time: %d seconds",
472+
elapsed,
473+
)
474+
await asyncio.sleep(poll_interval_seconds)
475+
476+
369477
def _validate_dataset_agent_data(
370478
dataset: types.EvaluationDataset,
371479
inference_configs: Optional[dict[str, Any]] = None,

0 commit comments

Comments
 (0)