Skip to content

Commit 6daa361

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI Client(evals) - add core data models and code-gen mapping for auto-loss analysis
PiperOrigin-RevId: 892978545
1 parent 8710b02 commit 6daa361

File tree

5 files changed

+717
-1
lines changed

5 files changed

+717
-1
lines changed

tests/unit/vertexai/genai/test_evals.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
from vertexai._genai import _evals_metric_loaders
3535
from vertexai._genai import _gcs_utils
3636
from vertexai._genai import _observability_data_converter
37+
from vertexai._genai import _transformers
3738
from vertexai._genai import evals
3839
from vertexai._genai import types as vertexai_genai_types
40+
from vertexai._genai.types import common as common_types
3941
from google.genai import client
4042
from google.genai import errors as genai_errors
4143
from google.genai import types as genai_types
@@ -218,6 +220,51 @@ def test_get_api_client_with_none_location(
218220
mock_vertexai_client.assert_not_called()
219221

220222

223+
class TestTransformers:
    """Unit tests for transformers."""

    def test_t_inline_results(self):
        # Assemble the SDK-side fixture from named pieces for readability.
        metric_result = common_types.EvalCaseMetricResult(
            score=0.0,
            explanation="Failed tool use",
        )
        candidate_result = common_types.ResponseCandidateResult(
            response_index=0,
            metric_results={"tool_use_quality": metric_result},
        )
        case_result = common_types.EvalCaseResult(
            eval_case_index=0,
            response_candidate_results=[candidate_result],
        )
        dataset = common_types.EvaluationDataset(
            eval_cases=[
                common_types.EvalCase(
                    prompt=genai_types.Content(
                        parts=[genai_types.Part(text="test prompt")]
                    )
                )
            ]
        )
        eval_result = common_types.EvaluationResult(
            eval_case_results=[case_result],
            evaluation_dataset=[dataset],
            metadata=common_types.EvaluationRunMetadata(
                candidate_names=["gemini-pro"]
            ),
        )

        payload = _transformers.t_inline_results([eval_result])

        # One (case, candidate, metric) combination -> one API result.
        assert len(payload) == 1
        first = payload[0]
        assert first["metric"] == "tool_use_quality"
        assert first["request"]["prompt"]["text"] == "test prompt"
        candidate_results = first["candidate_results"]
        assert len(candidate_results) == 1
        assert candidate_results[0]["candidate"] == "gemini-pro"
        assert candidate_results[0]["score"] == 0.0
266+
267+
221268
class TestEvals:
222269
"""Unit tests for the GenAI client."""
223270

@@ -1754,7 +1801,7 @@ def test_run_inference_with_litellm_openai_request_format(
17541801
self,
17551802
mock_api_client_fixture,
17561803
):
1757-
"""Tests inference with LiteLLM where the row contains an chat completion request body."""
1804+
"""Tests inference with LiteLLM where the row contains a chat completion request body."""
17581805
with mock.patch(
17591806
"vertexai._genai._evals_common.litellm"
17601807
) as mock_litellm, mock.patch(

vertexai/_genai/_transformers.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.genai._common import get_value_by_path as getv
2121

2222
from . import _evals_constant
23+
from . import _evals_data_converters
2324
from . import types
2425

2526
_METRIC_RES_NAME_RE = r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$"
@@ -257,3 +258,118 @@ def t_metric_for_registry(
257258
raise ValueError(f"Unsupported metric type: {metric_name}")
258259

259260
return metric_payload_item
261+
262+
263+
def t_inline_results(
    eval_results: list[Any],
) -> list[dict[str, Any]]:
    """Transforms a list of SDK EvaluationResults into API EvaluationResults.

    Each (eval case, response candidate, metric) combination found in the
    inputs is flattened into one API-format dict with keys ``request``,
    ``metric`` and ``candidate_results``.

    Args:
        eval_results: SDK ``EvaluationResult`` objects (or dicts) to transform.

    Returns:
        A list of API-format evaluation result dicts.
    """
    api_results: list[dict[str, Any]] = []

    for eval_result in eval_results:
        metadata = getv(eval_result, ["metadata"])
        candidate_names = getv(metadata, ["candidate_names"]) if metadata else []
        candidate_names = candidate_names or []

        # Only the first dataset in the list is consulted for eval cases,
        # which supply each case's prompt / agent_data.
        eval_dataset = getv(eval_result, ["evaluation_dataset"])
        eval_cases: list[Any] = []
        if isinstance(eval_dataset, list) and eval_dataset:
            eval_cases = getv(eval_dataset[0], ["eval_cases"]) or []

        for case_result in getv(eval_result, ["eval_case_results"]) or []:
            case_idx = getv(case_result, ["eval_case_index"]) or 0
            eval_case = None
            if 0 <= case_idx < len(eval_cases):
                eval_case = eval_cases[case_idx]
            prompt_payload = _t_prompt_payload(eval_case)

            cand_results = getv(case_result, ["response_candidate_results"]) or []
            for resp_cand_result in cand_results:
                resp_idx = getv(resp_cand_result, ["response_index"]) or 0
                # Fall back to a positional name when the index has no
                # configured candidate name.
                if 0 <= resp_idx < len(candidate_names):
                    cand_name = candidate_names[resp_idx]
                else:
                    cand_name = f"candidate-{resp_idx}"

                metric_results = getv(resp_cand_result, ["metric_results"]) or {}
                for metric_name, metric_res in metric_results.items():
                    api_rubric_verdicts = [
                        verdict_dict
                        for verdict_dict in (
                            _t_rubric_verdict(verdict)
                            for verdict in getv(metric_res, ["rubric_verdicts"])
                            or []
                        )
                        if verdict_dict  # drop verdicts that carried no data
                    ]

                    candidate_result_payload: dict[str, Any] = {
                        "candidate": str(cand_name),
                        "metric": str(metric_name),
                    }
                    score = getv(metric_res, ["score"])
                    if score is not None:
                        candidate_result_payload["score"] = float(score)
                    explanation = getv(metric_res, ["explanation"])
                    if explanation:
                        candidate_result_payload["explanation"] = str(explanation)
                    if api_rubric_verdicts:
                        candidate_result_payload["rubric_verdicts"] = (
                            api_rubric_verdicts
                        )

                    api_results.append({
                        # Shallow-copy the prompt so sibling results for the
                        # same case don't alias one mutable dict.
                        "request": {"prompt": dict(prompt_payload)},
                        "metric": str(metric_name),
                        "candidate_results": [candidate_result_payload],
                    })

    return api_results


def _t_prompt_payload(eval_case: Any) -> dict[str, Any]:
    """Builds the request prompt payload for a single eval case (may be None)."""
    prompt_payload: dict[str, Any] = {}
    if not eval_case:
        return prompt_payload

    agent_data = getv(eval_case, ["agent_data"])
    prompt = getv(eval_case, ["prompt"])
    if agent_data:
        # agent_data takes precedence over a plain prompt.
        if hasattr(agent_data, "model_dump"):
            prompt_payload["agent_data"] = agent_data.model_dump()
        else:
            prompt_payload["agent_data"] = agent_data
    elif prompt:
        # The disable pragma must sit on the line with the protected access
        # for pylint to honor it.
        text = _evals_data_converters._get_content_text(prompt)  # pylint: disable=protected-access
        if text:
            prompt_payload["text"] = str(text)
    return prompt_payload


def _t_rubric_verdict(verdict: Any) -> dict[str, Any]:
    """Transforms one SDK rubric verdict to its API dict; empty if no data."""
    verdict_dict: dict[str, Any] = {}

    eval_rubric = getv(verdict, ["evaluated_rubric"])
    if eval_rubric:
        rubric_content = getv(eval_rubric, ["content"])
        if rubric_content:
            text = getv(rubric_content, ["text"])
            prop = getv(rubric_content, ["property"])

            content_dict: dict[str, Any] = {}
            if text:
                content_dict["text"] = str(text)
            if prop:
                desc = getv(prop, ["description"])
                if desc:
                    content_dict["property"] = {"description": str(desc)}
            verdict_dict["evaluated_rubric"] = {"content": content_dict}

    score = getv(verdict, ["score"])
    if score is not None:
        verdict_dict["score"] = float(score)

    explanation = getv(verdict, ["explanation"])
    if explanation:
        verdict_dict["explanation"] = str(explanation)

    return verdict_dict

vertexai/_genai/evals.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,36 @@ def _GenerateInstanceRubricsRequest_to_vertex(
595595
return to_object
596596

597597

598+
def _GenerateLossClustersParameters_to_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Converts _GenerateLossClustersParameters into the Vertex request dict.

    Args:
        from_object: The SDK parameters object (or dict) to convert.
        parent_object: Unused here; kept for converter-signature consistency.

    Returns:
        The camelCase request dict for the generateLossClusters RPC.
    """
    to_object: dict[str, Any] = {}
    if getv(from_object, ["location"]) is not None:
        setv(to_object, ["location"], getv(from_object, ["location"]))

    if getv(from_object, ["evaluation_set"]) is not None:
        setv(to_object, ["evaluationSet"], getv(from_object, ["evaluation_set"]))

    if getv(from_object, ["inline_results"]) is not None:
        # t_inline_results already returns a fresh list; no extra copy needed.
        setv(
            to_object,
            ["inlineResults", "evaluationResults"],
            t.t_inline_results(getv(from_object, ["inline_results"])),
        )

    if getv(from_object, ["configs"]) is not None:
        # list(...) is the idiomatic copy, replacing the identity comprehension.
        setv(to_object, ["configs"], list(getv(from_object, ["configs"])))

    if getv(from_object, ["config"]) is not None:
        setv(to_object, ["config"], getv(from_object, ["config"]))

    return to_object
626+
627+
598628
def _GenerateUserScenariosParameters_to_vertex(
599629
from_object: Union[dict[str, Any], object],
600630
parent_object: Optional[dict[str, Any]] = None,
@@ -1268,6 +1298,65 @@ def _generate_user_scenarios(
12681298
self._api_client._verify_response(return_value)
12691299
return return_value
12701300

1301+
def _generate_loss_clusters(
    self,
    *,
    location: Optional[str] = None,
    evaluation_set: Optional[str] = None,
    inline_results: Optional[list[types.EvaluationResultOrDict]] = None,
    configs: Optional[list[types.LossAnalysisConfigOrDict]] = None,
    config: Optional[types.GenerateLossClustersConfigOrDict] = None,
) -> types.GenerateLossClustersOperation:
    """Generates loss clusters from evaluation results.

    Args:
        location: Optional location for the request.
        evaluation_set: Resource name of a stored evaluation set to analyze.
        inline_results: Inline SDK evaluation results to analyze.
        configs: Loss-analysis configurations for the request.
        config: Request-level configuration (e.g. HTTP options).

    Returns:
        The ``GenerateLossClustersOperation`` parsed from the API response.

    Raises:
        ValueError: If called on a non-Vertex (Gemini API) client.
    """

    parameter_model = types._GenerateLossClustersParameters(
        location=location,
        evaluation_set=evaluation_set,
        inline_results=inline_results,
        configs=configs,
        config=config,
    )

    request_url_dict: Optional[dict[str, str]]
    # This RPC exists only on the Vertex AI backend.
    if not self._api_client.vertexai:
        raise ValueError("This method is only supported in the Vertex AI client.")
    else:
        request_dict = _GenerateLossClustersParameters_to_vertex(parameter_model)
        # "_url" holds path-template substitutions extracted by the converter.
        request_url_dict = request_dict.get("_url")
        if request_url_dict:
            path = ":generateLossClusters".format_map(request_url_dict)
        else:
            path = ":generateLossClusters"

    # "_query" holds query-string parameters extracted by the converter.
    query_params = request_dict.get("_query")
    if query_params:
        path = f"{path}?{urlencode(query_params)}"
    # TODO: remove the hack that pops config.
    request_dict.pop("config", None)

    # Per-request HTTP options come from the config object when provided.
    http_options: Optional[types.HttpOptions] = None
    if (
        parameter_model.config is not None
        and parameter_model.config.http_options is not None
    ):
        http_options = parameter_model.config.http_options

    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.encode_unserializable_types(request_dict)

    response = self._api_client.request("post", path, request_dict, http_options)

    # An empty body maps to an empty response dict.
    response_dict = {} if not response.body else json.loads(response.body)

    return_value = types.GenerateLossClustersOperation._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

    self._api_client._verify_response(return_value)
    return return_value
1359+
12711360
def _generate_rubrics(
12721361
self,
12731362
*,
@@ -2833,6 +2922,67 @@ async def _generate_user_scenarios(
28332922
self._api_client._verify_response(return_value)
28342923
return return_value
28352924

2925+
async def _generate_loss_clusters(
    self,
    *,
    location: Optional[str] = None,
    evaluation_set: Optional[str] = None,
    inline_results: Optional[list[types.EvaluationResultOrDict]] = None,
    configs: Optional[list[types.LossAnalysisConfigOrDict]] = None,
    config: Optional[types.GenerateLossClustersConfigOrDict] = None,
) -> types.GenerateLossClustersOperation:
    """Generates loss clusters from evaluation results (async variant).

    Args:
        location: Optional location for the request.
        evaluation_set: Resource name of a stored evaluation set to analyze.
        inline_results: Inline SDK evaluation results to analyze.
        configs: Loss-analysis configurations for the request.
        config: Request-level configuration (e.g. HTTP options).

    Returns:
        The ``GenerateLossClustersOperation`` parsed from the API response.

    Raises:
        ValueError: If called on a non-Vertex (Gemini API) client.
    """

    parameter_model = types._GenerateLossClustersParameters(
        location=location,
        evaluation_set=evaluation_set,
        inline_results=inline_results,
        configs=configs,
        config=config,
    )

    request_url_dict: Optional[dict[str, str]]
    # This RPC exists only on the Vertex AI backend.
    if not self._api_client.vertexai:
        raise ValueError("This method is only supported in the Vertex AI client.")
    else:
        request_dict = _GenerateLossClustersParameters_to_vertex(parameter_model)
        # "_url" holds path-template substitutions extracted by the converter.
        request_url_dict = request_dict.get("_url")
        if request_url_dict:
            path = ":generateLossClusters".format_map(request_url_dict)
        else:
            path = ":generateLossClusters"

    # "_query" holds query-string parameters extracted by the converter.
    query_params = request_dict.get("_query")
    if query_params:
        path = f"{path}?{urlencode(query_params)}"
    # TODO: remove the hack that pops config.
    request_dict.pop("config", None)

    # Per-request HTTP options come from the config object when provided.
    http_options: Optional[types.HttpOptions] = None
    if (
        parameter_model.config is not None
        and parameter_model.config.http_options is not None
    ):
        http_options = parameter_model.config.http_options

    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.encode_unserializable_types(request_dict)

    response = await self._api_client.async_request(
        "post", path, request_dict, http_options
    )

    # An empty body maps to an empty response dict.
    response_dict = {} if not response.body else json.loads(response.body)

    return_value = types.GenerateLossClustersOperation._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

    self._api_client._verify_response(return_value)
    return return_value
2985+
28362986
async def _generate_rubrics(
28372987
self,
28382988
*,

0 commit comments

Comments
 (0)