|
1 | 1 | import importlib.util |
2 | 2 | import re |
3 | 3 | import sys |
4 | | -from unittest.mock import MagicMock |
| 4 | +from unittest.mock import MagicMock, patch |
5 | 5 |
|
6 | 6 | import pytest |
7 | 7 | from braintrust.logger import BraintrustState |
@@ -78,6 +78,62 @@ def exact_match(input_value, output, expected): |
78 | 78 | assert result.summary.scores["exact_match"].score == 1.0 |
79 | 79 |
|
80 | 80 |
|
@pytest.mark.asyncio
async def test_run_evaluator_forwards_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
    """Check that the evaluator's base experiment id reaches ``summarize``.

    The experiment's ``summarize`` is replaced by a mock, so the test pins
    two things: the keyword arguments ``run_evaluator`` hands to it, and
    that whatever it returns is exposed unchanged as ``result.summary``.
    """

    def exact_match(input_value, output, expected):
        # Binary scorer: full credit only on an exact match.
        if output == expected:
            return 1.0
        return 0.0

    base_id = "base-exp-id"
    evaluator = Evaluator(
        project_name="test-project",
        eval_name="test-evaluator",
        data=[EvalCase(input=1, expected=1)],
        task=lambda input_value: input_value,
        scores=[exact_match],
        experiment_name=None,
        metadata=None,
        base_experiment_id=base_id,
    )

    exp = init_test_exp("test-evaluator", "test-project")
    # Sentinel summary object: identity is asserted below, not contents.
    expected_summary = MagicMock()
    summarize_mock = MagicMock(return_value=expected_summary)
    exp.summarize = summarize_mock

    result = await run_evaluator(
        experiment=exp,
        evaluator=evaluator,
        position=None,
        filters=[],
    )

    assert result.summary is expected_summary
    summarize_mock.assert_called_once_with(
        summarize_scores=True,
        comparison_experiment_id=base_id,
    )
| 108 | + |
| 109 | + |
def test_experiment_summarize_resolves_explicit_comparison_name(with_memory_logger, with_simulate_login):
    """Check that an explicit comparison id is resolved to a name.

    The API connection is swapped for a mock whose ``get_json`` answers only
    the two paths this test expects; any other path raises
    ``AssertionError``. The test then asserts the name on the returned
    summary and the arguments of both lookups.
    """
    exp = init_test_exp("test-evaluator", "test-project")
    fake_conn = MagicMock()

    def fake_get_json(path, args=None):
        # A fresh dict is built on every call, one per recognised path.
        if path == "experiment-comparison2":
            return {"scores": {}, "metrics": {}}
        if path != "v1/experiment/base-exp-id":
            raise AssertionError(f"Unexpected get_json call: {path}, {args}")
        return {"name": "base-exp"}

    fake_conn.get_json.side_effect = fake_get_json

    # ``api_conn`` is patched as a callable that returns the fake connection.
    with patch.object(exp.state, "api_conn", return_value=fake_conn):
        summary = exp.summarize(comparison_experiment_id="base-exp-id")

    assert summary.comparison_experiment_name == "base-exp"

    expected_comparison_args = {
        "experiment_id": "test-evaluator",
        "base_experiment_id": "base-exp-id",
    }
    fake_conn.get_json.assert_any_call("v1/experiment/base-exp-id")
    fake_conn.get_json.assert_any_call(
        "experiment-comparison2",
        args=expected_comparison_args,
    )
| 135 | + |
| 136 | + |
81 | 137 | @pytest.mark.asyncio |
82 | 138 | @pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed") |
83 | 139 | async def test_run_evaluator_exposes_validated_parameter_values_to_hooks(): |
|
0 commit comments