Skip to content

Commit bc7cfd9

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI Client(evals) - add core data models and code-gen mapping for auto-loss analysis
PiperOrigin-RevId: 892978545
1 parent 8710b02 commit bc7cfd9

File tree

5 files changed

+713
-0
lines changed

5 files changed

+713
-0
lines changed

tests/unit/vertexai/genai/test_evals.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
from vertexai._genai import _evals_metric_loaders
3535
from vertexai._genai import _gcs_utils
3636
from vertexai._genai import _observability_data_converter
37+
from vertexai._genai import _transformers
3738
from vertexai._genai import evals
3839
from vertexai._genai import types as vertexai_genai_types
40+
from vertexai._genai.types import common as common_types
3941
from google.genai import client
4042
from google.genai import errors as genai_errors
4143
from google.genai import types as genai_types
@@ -218,6 +220,51 @@ def test_get_api_client_with_none_location(
218220
mock_vertexai_client.assert_not_called()
219221

220222

223+
class TestTransformers:
    """Unit tests for transformers."""

    def test_t_inline_results(self):
        # Build a minimal SDK EvaluationResult: one eval case, one response
        # candidate, one metric result, plus the prompt dataset and the
        # candidate-name metadata the transformer resolves names from.
        metric_result = common_types.EvalCaseMetricResult(
            score=0.0,
            explanation="Failed tool use",
        )
        candidate_result = common_types.ResponseCandidateResult(
            response_index=0,
            metric_results={"tool_use_quality": metric_result},
        )
        case_result = common_types.EvalCaseResult(
            eval_case_index=0,
            response_candidate_results=[candidate_result],
        )
        prompt_content = genai_types.Content(
            parts=[genai_types.Part(text="test prompt")]
        )
        dataset = common_types.EvaluationDataset(
            eval_cases=[common_types.EvalCase(prompt=prompt_content)]
        )
        eval_result = common_types.EvaluationResult(
            eval_case_results=[case_result],
            evaluation_dataset=[dataset],
            metadata=common_types.EvaluationRunMetadata(
                candidate_names=["gemini-pro"]
            ),
        )

        payload = _transformers.t_inline_results([eval_result])

        # One (case, candidate, metric) triple in -> one API result out.
        assert len(payload) == 1
        entry = payload[0]
        assert entry["metric"] == "tool_use_quality"
        assert entry["request"]["prompt"]["text"] == "test prompt"
        candidate_results = entry["candidate_results"]
        assert len(candidate_results) == 1
        assert candidate_results[0]["candidate"] == "gemini-pro"
        assert candidate_results[0]["score"] == 0.0
267+
221268
class TestEvals:
222269
"""Unit tests for the GenAI client."""
223270

vertexai/_genai/_transformers.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.genai._common import get_value_by_path as getv
2121

2222
from . import _evals_constant
23+
from . import _evals_data_converters
2324
from . import types
2425

2526
_METRIC_RES_NAME_RE = r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$"
@@ -257,3 +258,118 @@ def t_metric_for_registry(
257258
raise ValueError(f"Unsupported metric type: {metric_name}")
258259

259260
return metric_payload_item
261+
262+
263+
def _t_prompt_payload(eval_case: Any) -> dict[str, Any]:
    """Builds the API ``request.prompt`` payload for one SDK eval case.

    Prefers structured ``agent_data`` when present; otherwise falls back to
    the flattened text of the case's ``prompt`` content. Returns an empty
    dict when ``eval_case`` is falsy or carries neither field.
    """
    prompt_payload: dict[str, Any] = {}
    if not eval_case:
        return prompt_payload

    agent_data = getv(eval_case, ["agent_data"])
    prompt = getv(eval_case, ["prompt"])

    if agent_data:
        # model_dump is the pydantic serializer; plain dicts pass through.
        if hasattr(agent_data, "model_dump"):
            prompt_payload["agent_data"] = agent_data.model_dump()
        else:
            prompt_payload["agent_data"] = agent_data
    elif prompt:
        text = _evals_data_converters._get_content_text(
            prompt
        )  # pylint: disable=protected-access
        if text:
            prompt_payload["text"] = str(text)
    return prompt_payload


def _t_rubric_verdict(verdict: Any) -> dict[str, Any]:
    """Transforms one SDK rubric verdict into its API dict form.

    Returns an empty dict when the verdict has no transferable fields, so
    callers can filter empties out.
    """
    verdict_dict: dict[str, Any] = {}
    eval_rubric = getv(verdict, ["evaluated_rubric"])

    if eval_rubric:
        rubric_content = getv(eval_rubric, ["content"])
        if rubric_content:
            text = getv(rubric_content, ["text"])
            prop = getv(rubric_content, ["property"])

            content_dict: dict[str, Any] = {}
            if text:
                content_dict["text"] = str(text)
            if prop:
                desc = getv(prop, ["description"])
                if desc:
                    content_dict["property"] = {"description": str(desc)}
            verdict_dict["evaluated_rubric"] = {"content": content_dict}

    score = getv(verdict, ["score"])
    if score is not None:
        verdict_dict["score"] = float(score)

    explanation = getv(verdict, ["explanation"])
    if explanation:
        verdict_dict["explanation"] = str(explanation)

    return verdict_dict


def t_inline_results(
    eval_results: list[Any],
) -> list[dict[str, Any]]:
    """Transforms a list of SDK EvaluationResults into API EvaluationResults.

    Emits one API result dict per (eval case, response candidate, metric)
    triple found in the inputs. Candidate display names come from
    ``metadata.candidate_names`` when the response index is in range;
    otherwise a synthetic ``candidate-<idx>`` name is used.

    Args:
        eval_results: SDK EvaluationResult objects (or dicts readable by
            ``get_value_by_path``).

    Returns:
        A list of API-shaped dicts, each with ``request``, ``metric`` and
        ``candidate_results`` keys.
    """
    api_results: list[dict[str, Any]] = []

    for eval_result in eval_results:
        metadata = getv(eval_result, ["metadata"])
        candidate_names = getv(metadata, ["candidate_names"]) if metadata else []
        candidate_names = candidate_names or []

        # Only the first dataset is consulted for prompts; indices in
        # eval_case_results refer into its eval_cases list.
        eval_dataset = getv(eval_result, ["evaluation_dataset"])
        eval_cases: list[Any] = []
        if isinstance(eval_dataset, list) and eval_dataset:
            eval_cases = getv(eval_dataset[0], ["eval_cases"]) or []

        eval_case_results = getv(eval_result, ["eval_case_results"]) or []

        for case_result in eval_case_results:
            case_idx = getv(case_result, ["eval_case_index"]) or 0

            # Out-of-range indices degrade to an empty prompt payload
            # rather than raising.
            eval_case = None
            if 0 <= case_idx < len(eval_cases):
                eval_case = eval_cases[case_idx]
            prompt_payload = _t_prompt_payload(eval_case)

            cand_results = getv(case_result, ["response_candidate_results"]) or []
            for resp_cand_result in cand_results:
                resp_idx = getv(resp_cand_result, ["response_index"]) or 0
                cand_name = f"candidate-{resp_idx}"
                if 0 <= resp_idx < len(candidate_names):
                    cand_name = candidate_names[resp_idx]

                metric_results = getv(resp_cand_result, ["metric_results"]) or {}

                for metric_name, metric_res in metric_results.items():
                    rubric_verdicts = getv(metric_res, ["rubric_verdicts"]) or []
                    # Drop verdicts that mapped to nothing.
                    api_rubric_verdicts = [
                        v for v in map(_t_rubric_verdict, rubric_verdicts) if v
                    ]

                    score = getv(metric_res, ["score"])
                    explanation = getv(metric_res, ["explanation"])

                    candidate_result_payload: dict[str, Any] = {
                        "candidate": str(cand_name),
                        "metric": str(metric_name),
                    }
                    if score is not None:
                        candidate_result_payload["score"] = float(score)
                    if explanation:
                        candidate_result_payload["explanation"] = str(explanation)
                    if api_rubric_verdicts:
                        candidate_result_payload["rubric_verdicts"] = (
                            api_rubric_verdicts
                        )

                    api_results.append(
                        {
                            "request": {"prompt": prompt_payload},
                            "metric": str(metric_name),
                            "candidate_results": [candidate_result_payload],
                        }
                    )

    return api_results

vertexai/_genai/evals.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,33 @@ def _GenerateInstanceRubricsRequest_to_vertex(
595595
return to_object
596596

597597

598+
def _GenerateLossClustersParameters_to_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Maps SDK GenerateLossClusters parameters to the Vertex request dict.

    Snake_case SDK fields become camelCase API fields; fields that are unset
    (``None``) are omitted entirely. ``inline_results`` is converted to the
    API wire format via ``t_inline_results`` and nested under
    ``inlineResults.evaluationResults``.

    Args:
        from_object: The SDK parameter model (or equivalent dict).
        parent_object: Unused; kept for signature parity with the other
            ``*_to_vertex`` converters in this module.

    Returns:
        The request dict to send to the ``:generateLossClusters`` endpoint.
    """
    to_object: dict[str, Any] = {}
    if getv(from_object, ["location"]) is not None:
        setv(to_object, ["location"], getv(from_object, ["location"]))

    if getv(from_object, ["evaluation_set"]) is not None:
        setv(to_object, ["evaluationSet"], getv(from_object, ["evaluation_set"]))

    if getv(from_object, ["inline_results"]) is not None:
        setv(
            to_object,
            ["inlineResults", "evaluationResults"],
            t.t_inline_results(getv(from_object, ["inline_results"])),
        )

    if getv(from_object, ["configs"]) is not None:
        # list() shallow-copies so later mutation of the request dict cannot
        # alias the caller's list (replaces a manual copy comprehension).
        setv(to_object, ["configs"], list(getv(from_object, ["configs"])))

    if getv(from_object, ["config"]) is not None:
        setv(to_object, ["config"], getv(from_object, ["config"]))

    return to_object
623+
624+
598625
def _GenerateUserScenariosParameters_to_vertex(
599626
from_object: Union[dict[str, Any], object],
600627
parent_object: Optional[dict[str, Any]] = None,
@@ -1268,6 +1295,65 @@ def _generate_user_scenarios(
12681295
self._api_client._verify_response(return_value)
12691296
return return_value
12701297

1298+
def _generate_loss_clusters(
    self,
    *,
    location: Optional[str] = None,
    evaluation_set: Optional[str] = None,
    inline_results: Optional[list[types.EvaluationResultOrDict]] = None,
    configs: Optional[list[types.LossAnalysisConfigOrDict]] = None,
    config: Optional[types.GenerateLossClustersConfigOrDict] = None,
) -> types.GenerateLossClustersOperation:
    """Generates loss clusters from evaluation results.

    Issues a POST to the Vertex-only ``:generateLossClusters`` endpoint and
    wraps the response in a ``GenerateLossClustersOperation``.

    Args:
        location: Optional location for the request.
        evaluation_set: Resource name of a stored evaluation set to analyze.
        inline_results: SDK evaluation results sent inline with the request
            (converted to API form by the parameter converter).
        configs: Per-analysis loss-analysis configurations.
        config: Request-level config; only ``http_options`` is forwarded to
            the transport — the mapped ``config`` entry is removed from the
            request body before sending (see TODO below).

    Returns:
        The operation parsed from the service response.

    Raises:
        ValueError: If this client is not a Vertex AI client.
    """

    # Normalize raw keyword args through the parameter model (presumably
    # pydantic — it exposes model_dump; TODO confirm) before conversion.
    parameter_model = types._GenerateLossClustersParameters(
        location=location,
        evaluation_set=evaluation_set,
        inline_results=inline_results,
        configs=configs,
        config=config,
    )

    request_url_dict: Optional[dict[str, str]]
    if not self._api_client.vertexai:
        # This RPC exists only on the Vertex backend.
        raise ValueError("This method is only supported in the Vertex AI client.")
    else:
        request_dict = _GenerateLossClustersParameters_to_vertex(parameter_model)
        # "_url" would carry path-template values injected by the converter;
        # the path below has no placeholders, so format_map is a no-op kept
        # for symmetry with the other generated request methods.
        request_url_dict = request_dict.get("_url")
        if request_url_dict:
            path = ":generateLossClusters".format_map(request_url_dict)
        else:
            path = ":generateLossClusters"

    # "_query" entries become URL query parameters rather than body fields.
    query_params = request_dict.get("_query")
    if query_params:
        path = f"{path}?{urlencode(query_params)}"
    # TODO: remove the hack that pops config.
    request_dict.pop("config", None)

    # Honor per-request HTTP options from the config even though the config
    # itself is stripped from the request body above.
    http_options: Optional[types.HttpOptions] = None
    if (
        parameter_model.config is not None
        and parameter_model.config.http_options is not None
    ):
        http_options = parameter_model.config.http_options

    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.encode_unserializable_types(request_dict)

    response = self._api_client.request("post", path, request_dict, http_options)

    # An empty body (e.g. 204-style response) maps to an empty dict.
    response_dict = {} if not response.body else json.loads(response.body)

    return_value = types.GenerateLossClustersOperation._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

    self._api_client._verify_response(return_value)
    return return_value
1356+
12711357
def _generate_rubrics(
12721358
self,
12731359
*,
@@ -2833,6 +2919,67 @@ async def _generate_user_scenarios(
28332919
self._api_client._verify_response(return_value)
28342920
return return_value
28352921

2922+
async def _generate_loss_clusters(
    self,
    *,
    location: Optional[str] = None,
    evaluation_set: Optional[str] = None,
    inline_results: Optional[list[types.EvaluationResultOrDict]] = None,
    configs: Optional[list[types.LossAnalysisConfigOrDict]] = None,
    config: Optional[types.GenerateLossClustersConfigOrDict] = None,
) -> types.GenerateLossClustersOperation:
    """Generates loss clusters from evaluation results (async).

    Async counterpart of the sync ``_generate_loss_clusters``: same request
    construction, but the POST goes through ``async_request`` and is awaited.

    Args:
        location: Optional location for the request.
        evaluation_set: Resource name of a stored evaluation set to analyze.
        inline_results: SDK evaluation results sent inline with the request
            (converted to API form by the parameter converter).
        configs: Per-analysis loss-analysis configurations.
        config: Request-level config; only ``http_options`` is forwarded to
            the transport — the mapped ``config`` entry is removed from the
            request body before sending (see TODO below).

    Returns:
        The operation parsed from the service response.

    Raises:
        ValueError: If this client is not a Vertex AI client.
    """

    # Normalize raw keyword args through the parameter model (presumably
    # pydantic — it exposes model_dump; TODO confirm) before conversion.
    parameter_model = types._GenerateLossClustersParameters(
        location=location,
        evaluation_set=evaluation_set,
        inline_results=inline_results,
        configs=configs,
        config=config,
    )

    request_url_dict: Optional[dict[str, str]]
    if not self._api_client.vertexai:
        # This RPC exists only on the Vertex backend.
        raise ValueError("This method is only supported in the Vertex AI client.")
    else:
        request_dict = _GenerateLossClustersParameters_to_vertex(parameter_model)
        # "_url" would carry path-template values injected by the converter;
        # the path below has no placeholders, so format_map is a no-op kept
        # for symmetry with the other generated request methods.
        request_url_dict = request_dict.get("_url")
        if request_url_dict:
            path = ":generateLossClusters".format_map(request_url_dict)
        else:
            path = ":generateLossClusters"

    # "_query" entries become URL query parameters rather than body fields.
    query_params = request_dict.get("_query")
    if query_params:
        path = f"{path}?{urlencode(query_params)}"
    # TODO: remove the hack that pops config.
    request_dict.pop("config", None)

    # Honor per-request HTTP options from the config even though the config
    # itself is stripped from the request body above.
    http_options: Optional[types.HttpOptions] = None
    if (
        parameter_model.config is not None
        and parameter_model.config.http_options is not None
    ):
        http_options = parameter_model.config.http_options

    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.encode_unserializable_types(request_dict)

    response = await self._api_client.async_request(
        "post", path, request_dict, http_options
    )

    # An empty body (e.g. 204-style response) maps to an empty dict.
    response_dict = {} if not response.body else json.loads(response.body)

    return_value = types.GenerateLossClustersOperation._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

    self._api_client._verify_response(return_value)
    return return_value
2982+
28362983
async def _generate_rubrics(
28372984
self,
28382985
*,

0 commit comments

Comments (0)