Commit ad36123

vertex-sdk-bot authored and copybara-github committed

feat: Eval SDK: Migrate model call method to genai SDK usage in preview folder
PiperOrigin-RevId: 893099457
1 parent 3de2c1e commit ad36123

File tree: 9 files changed, +205 -14 lines

tests/unit/vertexai/test_evaluation.py

Lines changed: 23 additions & 0 deletions
@@ -2153,6 +2153,29 @@ def test_evaluate_invalid_metrics(self):
         )
         test_eval_task.evaluate()
 
+    @mock.patch("google.genai.Client")
+    def test_evaluate_model_genai(self, mock_client_class):
+        mock_client = mock.MagicMock()
+        mock_client.models.generate_content.return_value = mock.MagicMock(
+            text="test_response"
+        )
+        mock_client_class.return_value = mock_client
+        test_eval_task = EvalTaskPreview(
+            dataset=_TEST_EVAL_DATASET_WITHOUT_RESPONSE,
+            metrics=[PointwisePreview.SUMMARIZATION_QUALITY],
+        )
+        with mock.patch.object(
+            target=gapic_evaluation_services_preview.EvaluationServiceClient,
+            attribute="evaluate_instances",
+            side_effect=_MOCK_SUMMARIZATION_QUALITY_RESULT_PREVIEW,
+        ):
+            test_result = test_eval_task.evaluate(
+                model="gemini-2.5-pro",
+                prompt_template="{instruction} test prompt template {context}",
+            )
+        assert mock_client.models.generate_content.call_count == 2
+        assert "summarization_quality/score" in test_result.metrics_table.columns
+
     def test_evaluate_duplicate_string_metric(self):
         metrics = [
             "exact_match",

tests/unit/vertexai/test_rubric_based_eval.py

Lines changed: 29 additions & 0 deletions
@@ -215,6 +215,35 @@ def test_pointwise_instruction_following_metric(self):
             "rb_instruction_following/raw_outputs",
         ]
 
+    @mock.patch("google.genai.Client")
+    def test_pointwise_instruction_following_metric_genai(self, mock_client_class):
+        import copy
+
+        metric = copy.deepcopy(PredefinedRubricMetrics.Pointwise.INSTRUCTION_FOLLOWING)
+        metric.generation_config.model = "gemini-2.5-pro"
+        mock_client = mock.MagicMock()
+        mock_client.models.generate_content.return_value = mock.MagicMock(
+            text="""```json{"questions": ["test_rubric"]}```"""
+        )
+        mock_client_class.return_value = mock_client
+        with mock.patch.object(
+            target=gapic_evaluation_services.EvaluationServiceClient,
+            attribute="evaluate_instances",
+            side_effect=_MOCK_POINTWISE_RESPONSE,
+        ):
+            eval_result = EvalTask(
+                dataset=_TEST_EVAL_DATASET, metrics=[metric]
+            ).evaluate()
+        assert eval_result.metrics_table.columns.tolist() == [
+            "prompt",
+            "response",
+            "rubrics",
+            "rb_instruction_following/score",
+            "rb_instruction_following/rubric_verdict_pairs",
+            "rb_instruction_following/raw_outputs",
+        ]
+        assert mock_client.models.generate_content.call_count == 3
+
     def test_pairwise_instruction_following_metric(self):
         metric = PredefinedRubricMetrics.Pairwise.INSTRUCTION_FOLLOWING
         mock_model = mock.create_autospec(

vertexai/preview/evaluation/_evaluation.py

Lines changed: 15 additions & 1 deletion
@@ -74,7 +74,7 @@
 ]
 
 _RunnableType = Union[reasoning_engines.Queryable, Callable[[str], Dict[str, str]]]
-_ModelType = Union[generative_models.GenerativeModel, Callable[[str], str]]
+_ModelType = Union[str, generative_models.GenerativeModel, Callable[[str], str]]
 
 
 def _validate_metrics(metrics: List[Union[str, metrics_base._Metric]]) -> None:
@@ -399,6 +399,11 @@ def _run_model_inference(
     if constants.Dataset.PROMPT_COLUMN in evaluation_run_config.dataset.columns:
         t1 = time.perf_counter()
         if isinstance(model, generative_models.GenerativeModel):
+            _LOGGER.warning(
+                "vertexai.generative_models.GenerativeModel is deprecated for "
+                "evaluation and will be removed in June 2026. Please pass a "
+                "string model name instead."
+            )
             responses = _pre_eval_utils._generate_responses_from_gemini_model(
                 model, evaluation_run_config.dataset
             )
@@ -407,6 +412,15 @@ def _run_model_inference(
                 evaluation_run_config,
                 is_baseline_model,
             )
+        elif isinstance(model, str):
+            responses = _pre_eval_utils._generate_responses_from_genai_model(
+                model, evaluation_run_config.dataset
+            )
+            _pre_eval_utils.populate_eval_dataset_with_model_responses(
+                responses,
+                evaluation_run_config,
+                is_baseline_model,
+            )
         elif callable(model):
             responses = _pre_eval_utils._generate_response_from_custom_model_fn(
                 model, evaluation_run_config.dataset

vertexai/preview/evaluation/_pre_eval_utils.py

Lines changed: 103 additions & 0 deletions
@@ -21,6 +21,8 @@
 from concurrent import futures
 from typing import Callable, Optional, Set, TYPE_CHECKING, Union, List
 
+from google import genai
+from google.cloud import aiplatform
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform_v1beta1.types import (
     content as gapic_content_types,
@@ -70,6 +72,107 @@ def _assemble_prompt(
     )
 
 
+def _generate_content_text_response_genai(
+    model: str, client: genai.Client, prompt: str, max_retries: int = 3
+) -> str:
+    """Generates a text response from a Gemini model for a text prompt, with retries, using the genai module.
+
+    Args:
+        model: The model name string.
+        client: The genai client instance.
+        prompt: The prompt to send to the model.
+        max_retries: Maximum number of retries for response generation.
+
+    Returns:
+        The text response from the model, or constants.RESPONSE_ERROR if
+        there is still an error after all retries.
+    """
+    for retry_attempt in range(max_retries):
+        try:
+            response = client.models.generate_content(
+                model=model,
+                contents=prompt,
+            )
+            # The new SDK raises exceptions on blocked content instead of returning
+            # block_reason directly, so if it succeeds, we can return the text.
+            if response.text:
+                return response.text
+            else:
+                _LOGGER.warning(
+                    "The model response was empty or blocked.\n"
+                    f"Prompt: {prompt}.\n"
+                    f"Retry attempt: {retry_attempt + 1}/{max_retries}"
+                )
+        except Exception as e:  # pylint: disable=broad-except
+            error_message = (
+                f"Failed to generate response candidates from GenAI model "
+                f"{model}.\n"
+                f"Error: {e}.\n"
+                f"Prompt: {prompt}.\n"
+                f"Retry attempt: {retry_attempt + 1}/{max_retries}"
+            )
+            _LOGGER.warning(error_message)
+        if retry_attempt < max_retries - 1:
+            _LOGGER.info(
+                f"Retrying response generation for prompt: {prompt}, attempt "
+                f"{retry_attempt + 1}/{max_retries}..."
+            )
+
+    final_error_message = (
+        f"Failed to generate response from GenAI model {model}.\n"
+        f"Prompt: {prompt}."
+    )
+    _LOGGER.warning(final_error_message)
+    return constants.RESPONSE_ERROR
+
+
+def _generate_responses_from_genai_model(
+    model: str,
+    df: "pd.DataFrame",
+    rubric_generation_prompt_template: Optional[str] = None,
+) -> List[str]:
+    """Generates responses via the Google GenAI SDK for the given evaluation dataset."""
+    _LOGGER.info(
+        f"Generating a total of {df.shape[0]} "
+        f"responses from Google GenAI model {model}."
+    )
+    tasks = []
+    client = genai.Client(
+        vertexai=True,
+        project=aiplatform.initializer.global_config.project,
+        location=aiplatform.initializer.global_config.location,
+    )
+
+    with tqdm(total=len(df)) as pbar:
+        with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
+            for idx, row in df.iterrows():
+                if rubric_generation_prompt_template:
+                    input_columns = prompt_template_base.PromptTemplate(
+                        rubric_generation_prompt_template
+                    ).variables
+                    if multimodal_utils.is_multimodal_instance(
+                        row[list(input_columns)].to_dict()
+                    ):
+                        prompt = multimodal_utils._assemble_multi_modal_prompt(
+                            rubric_generation_prompt_template, row, idx, input_columns
+                        )
+                    else:
+                        prompt = _assemble_prompt(
+                            row, rubric_generation_prompt_template
+                        )
+                else:
+                    prompt = row[constants.Dataset.PROMPT_COLUMN]
+                task = executor.submit(
+                    _generate_content_text_response_genai,
+                    prompt=prompt,
+                    model=model,
+                    client=client,
+                )
+                task.add_done_callback(lambda _: pbar.update(1))
+                tasks.append(task)
+            responses = [future.result() for future in tasks]
+    return responses
+
+
 def _generate_content_text_response(
     model: generative_models.GenerativeModel, prompt: str, max_attempts: int = 3
 ) -> str:
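The helper added above wraps the following google-genai call in a bounded retry loop and fans it out over a thread pool. A stripped-down sketch of that underlying call; the project, location, model, and prompt are illustrative:

from google import genai

# Assumes application-default credentials; project and location are placeholders.
client = genai.Client(vertexai=True, project="my-project", location="us-central1")

response = client.models.generate_content(
    model="gemini-2.5-pro",
    contents="Summarize: ...",
)
# Empty or blocked responses surface as falsy `response.text` or raised
# exceptions, which is what the retry loop above handles.
print(response.text)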

vertexai/preview/evaluation/eval_task.py

Lines changed: 7 additions & 1 deletion
@@ -63,7 +63,7 @@
 GenerativeModel = generative_models.GenerativeModel
 
 _RunnableType = Union[reasoning_engines.Queryable, Callable[[str], Dict[str, str]]]
-_ModelType = Union[generative_models.GenerativeModel, Callable[[str], str]]
+_ModelType = Union[str, generative_models.GenerativeModel, Callable[[str], str]]
 
 
 class EvalTask:
@@ -579,6 +579,12 @@ def _log_eval_experiment_param(
                 for category, threshold in safety_settings.items()
             }
             eval_metadata.update(safety_settings_as_str)
+        elif isinstance(model, str):
+            eval_metadata.update(
+                {
+                    "model_name": model,
+                }
+            )
 
         if runnable:
             if isinstance(runnable, reasoning_engines.LangchainAgent):

vertexai/preview/evaluation/metric_utils.py

Lines changed: 2 additions & 2 deletions
@@ -185,7 +185,7 @@ def _parse_required_inputs(
 
 def load(
     file_path: str,
-    baseline_model: Optional[Union[GenerativeModel, Callable[[str], str]]] = None,
+    baseline_model: Optional[Union[str, GenerativeModel, Callable[[str], str]]] = None,
 ) -> Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric]:
     """Loads a metric object from a YAML file.
@@ -206,7 +206,7 @@ def loads(
 
 def loads(
     yaml_data: str,
-    baseline_model: Optional[Union[GenerativeModel, Callable[[str], str]]] = None,
+    baseline_model: Optional[Union[str, GenerativeModel, Callable[[str], str]]] = None,
 ) -> Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric]:
     """Loads a metric object from YAML data.
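With the widened signatures, YAML-defined metrics can name a baseline model as a string. A sketch of loading one this way; the file name (and the YAML it would contain) is hypothetical:

from vertexai.preview.evaluation import metric_utils

# "my_pairwise_metric.yaml" is a hypothetical metric definition file.
metric = metric_utils.load(
    file_path="my_pairwise_metric.yaml",
    baseline_model="gemini-2.5-pro",  # a string is now accepted alongside GenerativeModel
)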

vertexai/preview/evaluation/metrics/_base.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 )
 
 
-_ModelType = Union[generative_models.GenerativeModel, Callable[[str], str]]
+_ModelType = Union[str, generative_models.GenerativeModel, Callable[[str], str]]
 
 
 class _Metric(abc.ABC):

vertexai/preview/evaluation/metrics/pairwise_metric.py

Lines changed: 12 additions & 3 deletions
@@ -26,10 +26,13 @@
 from vertexai.preview.evaluation.metrics import (
     custom_output_config as custom_output_config_class,
 )
+from google.cloud.aiplatform import base
 from vertexai.preview.evaluation.metrics import (
     metric_prompt_template as metric_prompt_template_base,
 )
 
+_LOGGER = base.Logger(__name__)
+
 
 class PairwiseMetric(_base._ModelBasedMetric):  # pylint: disable=protected-access
     """A Model-based Pairwise Metric.
@@ -64,8 +67,8 @@ class PairwiseMetric(_base._ModelBasedMetric):  # pylint: disable=protected-access
     Usage Examples:
 
     ```
-    baseline_model = GenerativeModel("gemini-1.0-pro")
-    candidate_model = GenerativeModel("gemini-1.5-pro")
+    baseline_model = GenerativeModel("gemini-2.5-pro")
+    candidate_model = GenerativeModel("gemini-2.5-flash")
 
     pairwise_groundedness = PairwiseMetric(
         metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
@@ -96,7 +99,7 @@
             metric_prompt_template_base.PairwiseMetricPromptTemplate, str
         ],
         baseline_model: Optional[
-            Union[generative_models.GenerativeModel, Callable[[str], str]]
+            Union[str, generative_models.GenerativeModel, Callable[[str], str]]
         ] = None,
         system_instruction: Optional[str] = None,
         autorater_config: Optional[gapic_eval_service_types.AutoraterConfig] = None,
@@ -124,6 +127,12 @@
             autorater_config=autorater_config,
             custom_output_config=custom_output_config,
         )
+        if isinstance(baseline_model, generative_models.GenerativeModel):
+            _LOGGER.warning(
+                "vertexai.generative_models.GenerativeModel is deprecated for "
+                "evaluation and will be removed in June 2026. Please pass a "
+                "string model name instead."
+            )
         self._baseline_model = baseline_model
 
     @property
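Given the widened baseline_model type, the docstring's usage example can also be written without GenerativeModel. A sketch; the "pairwise_groundedness" example-template name and the model choice are illustrative assumptions:

from vertexai.preview.evaluation import MetricPromptTemplateExamples
from vertexai.preview.evaluation.metrics import PairwiseMetric

pairwise_groundedness = PairwiseMetric(
    metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
        "pairwise_groundedness"  # illustrative example-template name
    ),
    baseline_model="gemini-2.5-pro",  # plain string; no deprecation warning logged
)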

vertexai/preview/evaluation/metrics/rubric_based_metric.py

Lines changed: 13 additions & 6 deletions
@@ -30,7 +30,7 @@
 if TYPE_CHECKING:
     import pandas as pd
 
-_DEFAULT_MODEL_NAME = "gemini-2.0-flash-001"
+_DEFAULT_MODEL_NAME = "gemini-2.5-pro"
 _LOGGER = base.Logger(__name__)
 
 
@@ -73,11 +73,18 @@ def generate_rubrics(
             )
             return eval_dataset
 
-        responses = _pre_eval_utils._generate_responses_from_gemini_model(
-            model,
-            eval_dataset,
-            self.generation_config.prompt_template,
-        )
+        if isinstance(model, str):
+            responses = _pre_eval_utils._generate_responses_from_genai_model(
+                model,
+                eval_dataset,
+                self.generation_config.prompt_template,
+            )
+        else:
+            responses = _pre_eval_utils._generate_responses_from_gemini_model(
+                model,
+                eval_dataset,
+                self.generation_config.prompt_template,
+            )
         if self.generation_config.parsing_fn:
             parsing_fn = self.generation_config.parsing_fn
         else:
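Because _DEFAULT_MODEL_NAME is now a string, predefined rubric metrics generate rubrics through the genai path by default. A brief sketch mirroring the unit test above; the import path for PredefinedRubricMetrics follows the tests and is assumed:

import copy

from vertexai.preview.evaluation.metrics import PredefinedRubricMetrics

# Deep-copy so the shared predefined metric object is not mutated.
metric = copy.deepcopy(PredefinedRubricMetrics.Pointwise.INSTRUCTION_FOLLOWING)
metric.generation_config.model = "gemini-2.5-pro"  # a string model routes rubric generation through genai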
