
Commit ff5e246

vertex-sdk-bot authored and copybara-github committed
feat: Eval SDK: Migrate model call method by genai SDK usage
PiperOrigin-RevId: 893061063
1 parent bea67c2 commit ff5e246
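
In effect, EvalTask.evaluate can now take a plain model name string, which routes inference through the google-genai SDK instead of requiring a vertexai.generative_models.GenerativeModel instance. A minimal usage sketch mirroring the new unit test below; the project, location, and dataset contents are placeholders, and "exact_match" is one of the metrics the diff already exercises:

    import pandas as pd
    import vertexai
    from vertexai.evaluation import EvalTask

    vertexai.init(project="my-project", location="us-central1")  # placeholder values

    eval_dataset = pd.DataFrame(
        {
            "instruction": ["Summarize the text.", "Summarize the text."],
            "context": ["first passage", "second passage"],
            "reference": ["first summary", "second summary"],
        }
    )

    task = EvalTask(dataset=eval_dataset, metrics=["exact_match"])
    result = task.evaluate(
        model="gemini-2.5-pro",  # a string now dispatches to the genai-based path
        prompt_template="{instruction} test prompt template {context}",
    )
    print(result.summary_metrics)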

File tree

4 files changed: +249 -25 lines

tests/unit/vertexai/test_evaluation.py

Lines changed: 84 additions & 2 deletions
@@ -22,6 +22,7 @@
 from unittest import mock
 
 from google import auth
+from google import genai
 from google.auth import credentials as auth_credentials
 from google.cloud import aiplatform
 import vertexai
@@ -1025,6 +1026,62 @@ def test_compute_pointwise_metrics_metric_prompt_template_example(
             "explanation",
         ]
 
+    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
+    def test_compute_pointwise_metrics_metric_prompt_template_example_string_model(
+        self, api_transport
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            api_transport=api_transport,
+        )
+        mock_client = mock.create_autospec(genai.Client, instance=True)
+        mock_response = mock.MagicMock()
+        mock_response.text = "test_response"
+        mock_client.models.generate_content.return_value = mock_response
+
+        test_metrics = [Pointwise.SUMMARIZATION_QUALITY]
+        test_eval_task = EvalTask(
+            dataset=_TEST_EVAL_DATASET_WITHOUT_RESPONSE, metrics=test_metrics
+        )
+        mock_metric_results = _MOCK_SUMMARIZATION_QUALITY_RESULT
+        with mock.patch.object(genai, "Client", return_value=mock_client):
+            with mock.patch.object(
+                target=gapic_evaluation_services.EvaluationServiceClient,
+                attribute="evaluate_instances",
+                side_effect=mock_metric_results,
+            ):
+                test_result = test_eval_task.evaluate(
+                    model="gemini-1.5-pro",
+                    prompt_template="{instruction} test prompt template {context}",
+                )
+
+        assert test_result.summary_metrics["row_count"] == 2
+        assert test_result.summary_metrics["summarization_quality/mean"] == 4.5
+        assert test_result.summary_metrics[
+            "summarization_quality/std"
+        ] == pytest.approx(0.7, 0.1)
+        assert set(test_result.metrics_table.columns.values) == set(
+            [
+                "context",
+                "instruction",
+                "reference",
+                "prompt",
+                "response",
+                "summarization_quality/score",
+                "summarization_quality/explanation",
+            ]
+        )
+        assert list(
+            test_result.metrics_table["summarization_quality/score"].values
+        ) == [5, 4]
+        assert list(
+            test_result.metrics_table["summarization_quality/explanation"].values
+        ) == [
+            "explanation",
+            "explanation",
+        ]
+
     @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
     def test_compute_pointwise_metrics_without_model_inference(self, api_transport):
         aiplatform.init(
@@ -1401,13 +1458,13 @@ def test_compute_multiple_metrics(self, api_transport):
         mock_baseline_model.generate_content.return_value = (
             _MOCK_MODEL_INFERENCE_RESPONSE
         )
-        mock_baseline_model._model_name = "publishers/google/model/gemini-pro"
+        mock_baseline_model._model_name = "gemini-2.5-pro"
         _TEST_PAIRWISE_METRIC._baseline_model = mock_baseline_model
         mock_model = mock.create_autospec(
             generative_models.GenerativeModel, instance=True
         )
         mock_model.generate_content.return_value = _MOCK_MODEL_INFERENCE_RESPONSE
-        mock_model._model_name = "publishers/google/model/gemini-pro"
+        mock_model._model_name = "gemini-2.5-flash"
         test_metrics = [
             "exact_match",
             Pointwise.SUMMARIZATION_QUALITY,
@@ -2654,6 +2711,31 @@ def test_default_rubrics_parser_with_invalid_json(self):
         parsed_rubrics = utils_preview.parse_rubrics(_INVALID_UNPARSED_RUBRIC)
         assert parsed_rubrics == {"questions": ""}
 
+    def test_generate_responses_from_genai_model(self):
+        mock_client = mock.create_autospec(genai.Client, instance=True)
+        mock_response = mock.MagicMock()
+        mock_response.text = "test_response"
+        mock_client.models.generate_content.return_value = mock_response
+
+        with mock.patch.object(genai, "Client", return_value=mock_client):
+            evaluation_run_config = eval_base.EvaluationRunConfig(
+                dataset=_TEST_EVAL_DATASET_WITHOUT_RESPONSE.copy(),
+                metrics=[],
+                metric_column_mapping={},
+                client=mock.MagicMock(),
+                evaluation_service_qps=1,
+                retry_timeout=1,
+            )
+            _evaluation._generate_responses_from_genai_model(
+                "gemini-2.5-pro", evaluation_run_config
+            )
+
+        assert list(evaluation_run_config.dataset["response"].values) == [
+            "test_response",
+            "test_response",
+        ]
+        assert mock_client.models.generate_content.call_count == 2
+
     def test_generate_responses_from_gemini_model(self):
         mock_model = mock.create_autospec(
             generative_models.GenerativeModel, instance=True
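
For reference, the pattern these new tests use to stub the genai client, as a self-contained sketch (all values are stand-ins; nothing touches the network):

    from unittest import mock

    from google import genai

    mock_client = mock.create_autospec(genai.Client, instance=True)
    mock_response = mock.MagicMock()
    mock_response.text = "stubbed"
    mock_client.models.generate_content.return_value = mock_response

    with mock.patch.object(genai, "Client", return_value=mock_client):
        client = genai.Client()  # the code under test receives the mock
        text = client.models.generate_content(model="gemini-2.5-pro", contents="hi").text

    assert text == "stubbed"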

vertexai/evaluation/_evaluation.py

Lines changed: 113 additions & 2 deletions
@@ -22,6 +22,8 @@
 import time
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
 
+from google import genai
+from google.cloud import aiplatform
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform_v1beta1.types import (
     content as gapic_content_types,
@@ -373,6 +375,106 @@ def _generate_content_text_response(
     return constants.RESPONSE_ERROR
 
 
+def _generate_content_text_response_genai(
+    model: str, client: genai.Client, prompt: str, max_retries: int = 3
+) -> str:
+    """Generates a text response from Gemini model from a text prompt with retries using genai module.
+
+    Args:
+        model: The model name string.
+        client: The genai client instance.
+        prompt: The prompt to send to the model.
+        max_retries: Maximum number of retries for response generation.
+
+    Returns:
+        The text response from the model.
+        Returns constants.RESPONSE_ERROR if there is an error after all retries.
+    """
+    for retry_attempt in range(max_retries):
+        try:
+            response = client.models.generate_content(
+                model=model,
+                contents=prompt,
+            )
+            # The new SDK raises exceptions on blocked content instead of returning
+            # block_reason directly, so if it succeeds, we can return the text.
+            if response.text:
+                return response.text
+            else:
+                _LOGGER.warning(
+                    "The model response was empty or blocked.\n"
+                    f"Prompt: {prompt}.\n"
+                    f"Retry attempt: {retry_attempt + 1}/{max_retries}"
+                )
+        except Exception as e:  # pylint: disable=broad-except
+            error_message = (
+                f"Failed to generate response candidates from GenAI model "
+                f"{model}.\n"
+                f"Error: {e}.\n"
+                f"Prompt: {prompt}.\n"
+                f"Retry attempt: {retry_attempt + 1}/{max_retries}"
+            )
+            _LOGGER.warning(error_message)
+        if retry_attempt < max_retries - 1:
+            _LOGGER.info(
+                f"Retrying response generation for prompt: {prompt}, attempt "
+                f"{retry_attempt + 1}/{max_retries}..."
+            )
+
+    final_error_message = (
+        f"Failed to generate response from GenAI model {model}.\n" f"Prompt: {prompt}."
+    )
+    _LOGGER.warning(final_error_message)
+    return constants.RESPONSE_ERROR
+
+
+def _generate_responses_from_genai_model(
+    model: str,
+    evaluation_run_config: evaluation_base.EvaluationRunConfig,
+    is_baseline_model: bool = False,
+) -> None:
+    """Generates responses from Gemini model using genai module.
+
+    Args:
+        model: The model name string.
+        evaluation_run_config: Evaluation Run Configurations.
+        is_baseline_model: Whether the model is a baseline model for PairwiseMetric.
+    """
+    df = evaluation_run_config.dataset.copy()
+
+    _LOGGER.info(
+        f"Generating a total of {evaluation_run_config.dataset.shape[0]} "
+        f"responses from Gemini model {model} using genai module."
+    )
+    tasks = []
+    client = genai.Client(
+        vertexai=True,
+        project=aiplatform.initializer.global_config.project,
+        location=aiplatform.initializer.global_config.location,
+    )
+    with tqdm(total=len(df)) as pbar:
+        with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
+            for _, row in df.iterrows():
+                task = executor.submit(
+                    _generate_content_text_response_genai,
+                    prompt=row[constants.Dataset.PROMPT_COLUMN],
+                    model=model,
+                    client=client,
+                )
+                task.add_done_callback(lambda _: pbar.update(1))
+                tasks.append(task)
+    responses = [future.result() for future in tasks]
+    if is_baseline_model:
+        evaluation_run_config.dataset = df.assign(baseline_model_response=responses)
+    else:
+        evaluation_run_config.dataset = df.assign(response=responses)
+
+    _LOGGER.info(
+        f"All {evaluation_run_config.dataset.shape[0]} responses are successfully "
+        f"generated from Gemini model {model} using genai module."
+    )
+
+
 def _generate_responses_from_gemini_model(
     model: generative_models.GenerativeModel,
     evaluation_run_config: evaluation_base.EvaluationRunConfig,
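
For orientation, a minimal sketch of the call path the new helper wraps, assuming google-genai is installed and credentials are configured; the project and location values are placeholders:

    from google import genai
    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")  # placeholders
    client = genai.Client(
        vertexai=True,
        project=aiplatform.initializer.global_config.project,
        location=aiplatform.initializer.global_config.location,
    )
    response = client.models.generate_content(model="gemini-2.5-pro", contents="Say hi.")
    print(response.text)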
@@ -463,7 +565,7 @@ def _generate_response_from_custom_model_fn(
 
 
 def _run_model_inference(
-    model: Union[generative_models.GenerativeModel, Callable[[str], str]],
+    model: Union[str, generative_models.GenerativeModel, Callable[[str], str]],
     evaluation_run_config: evaluation_base.EvaluationRunConfig,
     response_column_name: str = constants.Dataset.MODEL_RESPONSE_COLUMN,
 ) -> None:
@@ -488,9 +590,18 @@ def _run_model_inference(
     if constants.Dataset.PROMPT_COLUMN in evaluation_run_config.dataset.columns:
         t1 = time.perf_counter()
         if isinstance(model, generative_models.GenerativeModel):
+            _LOGGER.warning(
+                "vertexai.generative_models.GenerativeModel is deprecated for "
+                "evaluation and will be removed in June 2026. Please pass a "
+                "string model name instead."
+            )
             _generate_responses_from_gemini_model(
                 model, evaluation_run_config, is_baseline_model
             )
+        elif isinstance(model, str):
+            _generate_responses_from_genai_model(
+                model, evaluation_run_config, is_baseline_model
+            )
         elif callable(model):
             _generate_response_from_custom_model_fn(
                 model, evaluation_run_config, is_baseline_model
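
Note the dispatch order above: GenerativeModel instances still work (now with a deprecation warning), strings go to the new genai path, and any Callable[[str], str] is still treated as a custom model function. A hypothetical callable for contrast:

    def my_custom_model(prompt: str) -> str:
        # stand-in for any user-supplied inference function
        return f"echo: {prompt}"

    # eval_task.evaluate(model=my_custom_model)  # the callable path is unchanged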
@@ -878,7 +989,7 @@ def evaluate(
     metrics: List[Union[str, metrics_base._Metric]],
     *,
     model: Optional[
-        Union[generative_models.GenerativeModel, Callable[[str], str]]
+        Union[str, generative_models.GenerativeModel, Callable[[str], str]]
     ] = None,
     prompt_template: Optional[Union[str, prompt_template_base.PromptTemplate]] = None,
     metric_column_mapping: Dict[str, str],
