Skip to content

Commit 9f97738

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Fix flaky evaluation test failures caused by thread execution order
PiperOrigin-RevId: 922227627
1 parent 8bb7bdf commit 9f97738

1 file changed

Lines changed: 52 additions & 6 deletions

File tree

tests/unit/vertexai/test_evaluation.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18+
import json
1819
import re
1920
import sys
2021
import threading
@@ -851,11 +852,26 @@ def test_compute_exact_match_metric(self, api_transport):
851852
)
852853
test_metrics = ["exact_match"]
853854
test_eval_task = EvalTask(dataset=eval_dataset, metrics=test_metrics)
854-
mock_metric_results = _MOCK_EXACT_MATCH_RESULT
855+
856+
def _exact_match_side_effect(**kwargs):
857+
request = kwargs.get("request")
858+
prediction = request.exact_match_input.instances[0].prediction
859+
reference = request.exact_match_input.instances[0].reference
860+
score = 1.0 if prediction == reference else 0.0
861+
return gapic_evaluation_service_types.EvaluateInstancesResponse(
862+
exact_match_results=gapic_evaluation_service_types.ExactMatchResults(
863+
exact_match_metric_values=[
864+
gapic_evaluation_service_types.ExactMatchMetricValue(
865+
score=score
866+
),
867+
]
868+
)
869+
)
870+
855871
with mock.patch.object(
856872
target=gapic_evaluation_services.EvaluationServiceClient,
857873
attribute="evaluate_instances",
858-
side_effect=mock_metric_results,
874+
side_effect=_exact_match_side_effect,
859875
):
860876
test_result = test_eval_task.evaluate()
861877

@@ -932,11 +948,26 @@ def test_compute_pointwise_metrics_free_string(self):
932948
metrics=[_TEST_POINTWISE_METRIC_FREE_STRING],
933949
metric_column_mapping={"abc": "prompt"},
934950
)
935-
mock_metric_results = _MOCK_POINTWISE_RESULT
951+
952+
def _pointwise_side_effect(**kwargs):
953+
request = kwargs.get("request")
954+
instance_data = json.loads(
955+
request.pointwise_metric_input.instance.json_instance
956+
)
957+
# Row with prompt "test_prompt" gets score 5, "text_prompt" gets 4.
958+
score = 5 if instance_data.get("abc") == "test_prompt" else 4
959+
return gapic_evaluation_service_types.EvaluateInstancesResponse(
960+
pointwise_metric_result=(
961+
gapic_evaluation_service_types.PointwiseMetricResult(
962+
score=score, explanation="explanation"
963+
)
964+
)
965+
)
966+
936967
with mock.patch.object(
937968
target=gapic_evaluation_services.EvaluationServiceClient,
938969
attribute="evaluate_instances",
939-
side_effect=mock_metric_results,
970+
side_effect=_pointwise_side_effect,
940971
):
941972
test_result = test_eval_task.evaluate()
942973

@@ -1095,11 +1126,26 @@ def test_compute_pointwise_metrics_without_model_inference(self, api_transport):
10951126
test_eval_task = EvalTask(
10961127
dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
10971128
)
1098-
mock_metric_results = _MOCK_SUMMARIZATION_QUALITY_RESULT
1129+
1130+
def _summarization_side_effect(**kwargs):
1131+
request = kwargs.get("request")
1132+
instance_data = json.loads(
1133+
request.pointwise_metric_input.instance.json_instance
1134+
)
1135+
# Row with response "test" gets score 5, "text" gets score 4.
1136+
score = 5 if instance_data.get("response") == "test" else 4
1137+
return gapic_evaluation_service_types.EvaluateInstancesResponse(
1138+
pointwise_metric_result=(
1139+
gapic_evaluation_service_types.PointwiseMetricResult(
1140+
score=score, explanation="explanation"
1141+
)
1142+
)
1143+
)
1144+
10991145
with mock.patch.object(
11001146
target=gapic_evaluation_services.EvaluationServiceClient,
11011147
attribute="evaluate_instances",
1102-
side_effect=mock_metric_results,
1148+
side_effect=_summarization_side_effect,
11031149
):
11041150
test_result = test_eval_task.evaluate()
11051151

0 commit comments

Comments
 (0)