|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 |
|
| 18 | +import json |
18 | 19 | import re |
19 | 20 | import sys |
20 | 21 | import threading |
@@ -851,11 +852,26 @@ def test_compute_exact_match_metric(self, api_transport): |
851 | 852 | ) |
852 | 853 | test_metrics = ["exact_match"] |
853 | 854 | test_eval_task = EvalTask(dataset=eval_dataset, metrics=test_metrics) |
854 | | - mock_metric_results = _MOCK_EXACT_MATCH_RESULT |
| 855 | + |
| 856 | + def _exact_match_side_effect(**kwargs): |
| 857 | + request = kwargs.get("request") |
| 858 | + prediction = request.exact_match_input.instances[0].prediction |
| 859 | + reference = request.exact_match_input.instances[0].reference |
| 860 | + score = 1.0 if prediction == reference else 0.0 |
| 861 | + return gapic_evaluation_service_types.EvaluateInstancesResponse( |
| 862 | + exact_match_results=gapic_evaluation_service_types.ExactMatchResults( |
| 863 | + exact_match_metric_values=[ |
| 864 | + gapic_evaluation_service_types.ExactMatchMetricValue( |
| 865 | + score=score |
| 866 | + ), |
| 867 | + ] |
| 868 | + ) |
| 869 | + ) |
| 870 | + |
855 | 871 | with mock.patch.object( |
856 | 872 | target=gapic_evaluation_services.EvaluationServiceClient, |
857 | 873 | attribute="evaluate_instances", |
858 | | - side_effect=mock_metric_results, |
| 874 | + side_effect=_exact_match_side_effect, |
859 | 875 | ): |
860 | 876 | test_result = test_eval_task.evaluate() |
861 | 877 |
|
@@ -932,11 +948,26 @@ def test_compute_pointwise_metrics_free_string(self): |
932 | 948 | metrics=[_TEST_POINTWISE_METRIC_FREE_STRING], |
933 | 949 | metric_column_mapping={"abc": "prompt"}, |
934 | 950 | ) |
935 | | - mock_metric_results = _MOCK_POINTWISE_RESULT |
| 951 | + |
| 952 | + def _pointwise_side_effect(**kwargs): |
| 953 | + request = kwargs.get("request") |
| 954 | + instance_data = json.loads( |
| 955 | + request.pointwise_metric_input.instance.json_instance |
| 956 | + ) |
| 957 | + # Row with prompt "test_prompt" gets score 5, "text_prompt" gets 4. |
| 958 | + score = 5 if instance_data.get("abc") == "test_prompt" else 4 |
| 959 | + return gapic_evaluation_service_types.EvaluateInstancesResponse( |
| 960 | + pointwise_metric_result=( |
| 961 | + gapic_evaluation_service_types.PointwiseMetricResult( |
| 962 | + score=score, explanation="explanation" |
| 963 | + ) |
| 964 | + ) |
| 965 | + ) |
| 966 | + |
936 | 967 | with mock.patch.object( |
937 | 968 | target=gapic_evaluation_services.EvaluationServiceClient, |
938 | 969 | attribute="evaluate_instances", |
939 | | - side_effect=mock_metric_results, |
| 970 | + side_effect=_pointwise_side_effect, |
940 | 971 | ): |
941 | 972 | test_result = test_eval_task.evaluate() |
942 | 973 |
|
@@ -1095,11 +1126,26 @@ def test_compute_pointwise_metrics_without_model_inference(self, api_transport): |
1095 | 1126 | test_eval_task = EvalTask( |
1096 | 1127 | dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics |
1097 | 1128 | ) |
1098 | | - mock_metric_results = _MOCK_SUMMARIZATION_QUALITY_RESULT |
| 1129 | + |
| 1130 | + def _summarization_side_effect(**kwargs): |
| 1131 | + request = kwargs.get("request") |
| 1132 | + instance_data = json.loads( |
| 1133 | + request.pointwise_metric_input.instance.json_instance |
| 1134 | + ) |
| 1135 | + # Row with response "test" gets score 5, "text" gets score 4. |
| 1136 | + score = 5 if instance_data.get("response") == "test" else 4 |
| 1137 | + return gapic_evaluation_service_types.EvaluateInstancesResponse( |
| 1138 | + pointwise_metric_result=( |
| 1139 | + gapic_evaluation_service_types.PointwiseMetricResult( |
| 1140 | + score=score, explanation="explanation" |
| 1141 | + ) |
| 1142 | + ) |
| 1143 | + ) |
| 1144 | + |
1099 | 1145 | with mock.patch.object( |
1100 | 1146 | target=gapic_evaluation_services.EvaluationServiceClient, |
1101 | 1147 | attribute="evaluate_instances", |
1102 | | - side_effect=mock_metric_results, |
| 1148 | + side_effect=_summarization_side_effect, |
1103 | 1149 | ): |
1104 | 1150 | test_result = test_eval_task.evaluate() |
1105 | 1151 |
|
|
0 commit comments