Skip to content

Commit 76ab098

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Add a warning that the autorater config is not applicable for predefined metrics in SDK
FUTURE_COPYBARA_INTEGRATE_REVIEW=#6808 from googleapis:release-please--branches--main bcc9533 PiperOrigin-RevId: 910847685
1 parent 7a68393 commit 76ab098

5 files changed

Lines changed: 401 additions & 341 deletions

File tree

agentplatform/_genai/_evals_metric_handlers.py

Lines changed: 76 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -960,88 +960,100 @@ def aggregate(
960960

961961

962962
class PredefinedMetricHandler(MetricHandler[types.Metric]):
963-
"""Metric handler for predefined metrics."""
963+
"""Metric handler for predefined metrics."""
964964

965-
@property
966-
def metric_name(self) -> str:
967-
return self.metric.name or "unknown_metric"
965+
@property
966+
def metric_name(self) -> str:
967+
return self.metric.name or "unknown_metric"
968968

969-
def __init__(self, module: "evals.Evals", metric: types.Metric):
970-
super().__init__(module=module, metric=metric)
971-
if self.metric.name not in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
972-
raise ValueError(
969+
def __init__(self, module: "evals.Evals", metric: types.Metric):
970+
super().__init__(module=module, metric=metric)
971+
if self.metric.name not in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
972+
raise ValueError(
973973
f"Metric '{self.metric.name}' is not a supported predefined metric."
974974
)
975-
976-
def _build_request_payload(
977-
self, eval_case: types.EvalCase, response_index: int
978-
) -> dict[str, Any]:
979-
"""Builds the request parameters for evaluate instances request."""
980-
response_content = _get_response_from_eval_case(
975+
if (
976+
self.metric.judge_model
977+
or self.metric.judge_model_generation_config
978+
or self.metric.judge_model_sampling_count
979+
):
980+
logger.warning(
981+
"Autorater config settings (judge_model, "
982+
"judge_model_generation_config, judge_model_sampling_count) "
983+
"are ignored for predefined metric '%s'.",
984+
self.metric.name,
985+
)
986+
987+
def _build_request_payload(
988+
self, eval_case: types.EvalCase, response_index: int
989+
) -> dict[str, Any]:
990+
"""Builds the request parameters for evaluate instances request."""
991+
response_content = _get_response_from_eval_case(
981992
eval_case, response_index, self.metric.name
982993
)
983994

984-
if not response_content and not getattr(eval_case, "agent_data", None):
985-
raise ValueError(
995+
if not response_content and not getattr(eval_case, "agent_data", None):
996+
raise ValueError(
986997
f"Response content missing for candidate {response_index}."
987998
)
988999

989-
if self.metric.name == "tool_use_quality_v1":
990-
has_tool_call = _has_tool_call(eval_case.intermediate_events)
1000+
if self.metric.name == "tool_use_quality_v1":
1001+
has_tool_call = _has_tool_call(eval_case.intermediate_events)
9911002

992-
# Check agent_data for tool calls if intermediate_events is empty
993-
agent_data = getattr(eval_case, "agent_data", None)
994-
if not has_tool_call and agent_data:
995-
for turn in agent_data.turns or []:
996-
if _has_tool_call(turn.events):
997-
has_tool_call = True
998-
break
1003+
# Check agent_data for tool calls if intermediate_events is empty
1004+
agent_data = getattr(eval_case, "agent_data", None)
1005+
if not has_tool_call and agent_data:
1006+
for turn in agent_data.turns or []:
1007+
if _has_tool_call(turn.events):
1008+
has_tool_call = True
1009+
break
9991010

1000-
if not has_tool_call:
1001-
logger.warning(
1011+
if not has_tool_call:
1012+
logger.warning(
10021013
"Metric 'tool_use_quality_v1' requires tool usage in "
10031014
"'intermediate_events' or 'agent_data', but no tool usage was found for case %s.",
10041015
eval_case.eval_case_id,
10051016
)
10061017

1007-
extracted_prompt = _get_prompt_from_eval_case(eval_case)
1008-
prompt_instance_data = None
1009-
if self.metric.name and self.metric.name.startswith("multi_turn"):
1010-
prompt_contents = [
1018+
extracted_prompt = _get_prompt_from_eval_case(eval_case)
1019+
prompt_instance_data = None
1020+
if self.metric.name and self.metric.name.startswith("multi_turn"):
1021+
prompt_contents = [
10111022
msg.content for msg in (eval_case.conversation_history or [])
10121023
]
1013-
if extracted_prompt:
1014-
prompt_contents.append(extracted_prompt)
1015-
prompt_instance_data = types.evals.InstanceData(
1024+
if extracted_prompt:
1025+
prompt_contents.append(extracted_prompt)
1026+
prompt_instance_data = types.evals.InstanceData(
10161027
contents=types.evals.InstanceDataContents(contents=prompt_contents)
10171028
)
10181029

1019-
instance_payload = _build_evaluation_instance(
1030+
instance_payload = _build_evaluation_instance(
10201031
eval_case=eval_case,
10211032
response_content=response_content,
10221033
prompt_instance_data=prompt_instance_data,
10231034
)
10241035

1025-
request_payload: dict[str, Any] = {
1036+
request_payload: dict[str, Any] = {
10261037
"instance": instance_payload,
10271038
}
10281039

1029-
autorater_config = _get_autorater_config(self.metric)
1030-
if autorater_config:
1031-
request_payload["autorater_config"] = genai_types.AutoraterConfig(
1040+
autorater_config = _get_autorater_config(self.metric)
1041+
if autorater_config:
1042+
request_payload["autorater_config"] = genai_types.AutoraterConfig(
10321043
**autorater_config
10331044
)
1034-
return request_payload
10351045

1036-
@override
1037-
def get_metric_result(
1046+
return request_payload
1047+
1048+
@override
1049+
def get_metric_result(
10381050
self, eval_case: types.EvalCase, response_index: int
10391051
) -> types.EvalCaseMetricResult:
1040-
"""Processes a single evaluation case for a specific predefined metric."""
1041-
metric_name = self.metric.name
1042-
try:
1043-
payload = self._build_request_payload(eval_case, response_index)
1044-
api_response = _call_with_retry(
1052+
"""Processes a single evaluation case for a specific predefined metric."""
1053+
metric_name = self.metric.name
1054+
try:
1055+
payload = self._build_request_payload(eval_case, response_index)
1056+
api_response = _call_with_retry(
10451057
lambda: self.module._evaluate_instances(
10461058
metrics=[self.metric],
10471059
instance=payload.get("instance"),
@@ -1050,25 +1062,25 @@ def get_metric_result(
10501062
metric_name,
10511063
)
10521064

1053-
if (
1065+
if (
10541066
api_response
10551067
and hasattr(api_response, "metric_results")
10561068
and api_response.metric_results
10571069
):
1058-
result_data = api_response.metric_results[0]
1070+
result_data = api_response.metric_results[0]
10591071

1060-
error_message = None
1061-
if result_data.error and getattr(result_data.error, "code"):
1062-
error_message = f"Error in metric result: {result_data.error}"
1063-
return types.EvalCaseMetricResult(
1072+
error_message = None
1073+
if result_data.error and getattr(result_data.error, "code"):
1074+
error_message = f"Error in metric result: {result_data.error}"
1075+
return types.EvalCaseMetricResult(
10641076
metric_name=metric_name,
10651077
score=result_data.score,
10661078
explanation=result_data.explanation,
10671079
rubric_verdicts=result_data.rubric_verdicts,
10681080
error_message=error_message,
10691081
)
1070-
else:
1071-
logger.error(
1082+
else:
1083+
logger.error(
10721084
"Metric results missing in API response for predefined metric '%s'."
10731085
" API response: %s",
10741086
metric_name,
@@ -1078,29 +1090,29 @@ def get_metric_result(
10781090
else "None"
10791091
),
10801092
)
1081-
return types.EvalCaseMetricResult(
1093+
return types.EvalCaseMetricResult(
10821094
metric_name=metric_name,
10831095
error_message="Metric results missing in API response.",
10841096
)
1085-
except Exception as e: # pylint: disable=broad-exception-caught
1086-
logger.error(
1097+
except Exception as e: # pylint: disable=broad-exception-caught
1098+
logger.error(
10871099
"Error processing metric %s for case %s: %s",
10881100
metric_name,
10891101
eval_case.eval_case_id,
10901102
e,
10911103
exc_info=True,
10921104
)
1093-
return types.EvalCaseMetricResult(
1105+
return types.EvalCaseMetricResult(
10941106
metric_name=metric_name, error_message=str(e)
10951107
)
10961108

1097-
@override
1098-
def aggregate(
1109+
@override
1110+
def aggregate(
10991111
self, eval_case_metric_results: list[types.EvalCaseMetricResult]
11001112
) -> types.AggregatedMetricResult:
1101-
"""Aggregates the metric results for a predefined metric."""
1102-
logger.debug("Aggregating results for predefined metric: %s", self.metric.name)
1103-
return _default_aggregate_scores(
1113+
"""Aggregates the metric results for a predefined metric."""
1114+
logger.debug("Aggregating results for predefined metric: %s", self.metric.name)
1115+
return _default_aggregate_scores(
11041116
self.metric.name, eval_case_metric_results, calculate_pass_rate=True
11051117
)
11061118

0 commit comments

Comments
 (0)