Skip to content

Commit 76ea5a6

Browse files
authored
fix(experiments): pass base_experiment_id to summarize (#512)
### Description Eval stores base_experiment_id correctly on the experiment but the final summary does not pass it as the explicit comparison ID. As a result, summary comparison can fall back to project/default baseline resolution and show wrong diffs. ### Fix Pass `evaluator.base_experiment_id` into `experiment.summarize(comparison_experiment_id=...)`, so score and metric diffs are computed against the explicit experiment baseline. Also resolve the explicit comparison experiment name so the returned summary displays the correct “compared to” name. Previously, `comparison_experiment_id` was `None`, so `summarize()` called `POST /api/base_experiment/get_id`; that resolver can apply UI/default-baseline behavior, including letting a project default baseline override the experiment’s explicit `base_exp_id`.
1 parent 28fae42 commit 76ea5a6

3 files changed

Lines changed: 108 additions & 2 deletions

File tree

py/src/braintrust/framework.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,14 @@ def _validate_classification_result(value: Any, classifier_name: str) -> Classif
13511351
return classification
13521352

13531353

1354+
def _get_persisted_base_experiment_id(experiment: Experiment) -> str | None:
1355+
try:
1356+
base_experiment_id = experiment.data.get("base_exp_id")
1357+
except Exception:
1358+
return None
1359+
return base_experiment_id if isinstance(base_experiment_id, str) and base_experiment_id else None
1360+
1361+
13541362
async def run_evaluator(
13551363
experiment: Experiment | None,
13561364
evaluator: Evaluator[Input, Output, Expected],
@@ -1367,7 +1375,13 @@ async def run_evaluator(
13671375
)
13681376

13691377
if experiment:
1370-
summary = experiment.summarize(summarize_scores=evaluator.summarize_scores)
1378+
comparison_experiment_id = evaluator.base_experiment_id
1379+
if comparison_experiment_id is None:
1380+
comparison_experiment_id = _get_persisted_base_experiment_id(experiment)
1381+
summary = experiment.summarize(
1382+
summarize_scores=evaluator.summarize_scores,
1383+
comparison_experiment_id=comparison_experiment_id,
1384+
)
13711385
else:
13721386
summary = build_local_summary(evaluator, results)
13731387

py/src/braintrust/logger.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3969,6 +3969,12 @@ def summarize(
39693969
if base_experiment:
39703970
comparison_experiment_id = base_experiment.id
39713971
comparison_experiment_name = base_experiment.name
3972+
else:
3973+
try:
3974+
comparison_experiment = state.api_conn().get_json(f"v1/experiment/{comparison_experiment_id}")
3975+
comparison_experiment_name = comparison_experiment.get("name")
3976+
except Exception:
3977+
pass
39723978

39733979
try:
39743980
summary_items = state.api_conn().get_json(

py/src/braintrust/test_framework.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib.util
22
import re
33
import sys
4-
from unittest.mock import MagicMock
4+
from unittest.mock import MagicMock, patch
55

66
import pytest
77
from braintrust.logger import BraintrustState
@@ -78,6 +78,92 @@ def exact_match(input_value, output, expected):
7878
assert result.summary.scores["exact_match"].score == 1.0
7979

8080

81+
@pytest.mark.asyncio
82+
async def test_run_evaluator_forwards_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
83+
def exact_match(input_value, output, expected):
84+
return 1.0 if output == expected else 0.0
85+
86+
evaluator = Evaluator(
87+
project_name="test-project",
88+
eval_name="test-evaluator",
89+
data=[EvalCase(input=1, expected=1)],
90+
task=lambda input_value: input_value,
91+
scores=[exact_match],
92+
experiment_name=None,
93+
metadata=None,
94+
base_experiment_id="base-exp-id",
95+
)
96+
97+
exp = init_test_exp("test-evaluator", "test-project")
98+
expected_summary = MagicMock()
99+
exp.summarize = MagicMock(return_value=expected_summary)
100+
101+
result = await run_evaluator(experiment=exp, evaluator=evaluator, position=None, filters=[])
102+
103+
assert result.summary is expected_summary
104+
exp.summarize.assert_called_once_with(
105+
summarize_scores=True,
106+
comparison_experiment_id="base-exp-id",
107+
)
108+
109+
110+
@pytest.mark.asyncio
111+
async def test_run_evaluator_forwards_persisted_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
112+
def exact_match(input_value, output, expected):
113+
return 1.0 if output == expected else 0.0
114+
115+
evaluator = Evaluator(
116+
project_name="test-project",
117+
eval_name="test-evaluator",
118+
data=[EvalCase(input=1, expected=1)],
119+
task=lambda input_value: input_value,
120+
scores=[exact_match],
121+
experiment_name=None,
122+
metadata=None,
123+
base_experiment_name="base-exp",
124+
)
125+
126+
exp = init_test_exp("test-evaluator", "test-project")
127+
exp.data["base_exp_id"] = "base-exp-id"
128+
expected_summary = MagicMock()
129+
exp.summarize = MagicMock(return_value=expected_summary)
130+
131+
result = await run_evaluator(experiment=exp, evaluator=evaluator, position=None, filters=[])
132+
133+
assert result.summary is expected_summary
134+
exp.summarize.assert_called_once_with(
135+
summarize_scores=True,
136+
comparison_experiment_id="base-exp-id",
137+
)
138+
139+
140+
def test_experiment_summarize_resolves_explicit_comparison_name(with_memory_logger, with_simulate_login):
141+
exp = init_test_exp("test-evaluator", "test-project")
142+
mock_conn = MagicMock()
143+
144+
def get_json(path, args=None):
145+
if path == "v1/experiment/base-exp-id":
146+
return {"name": "base-exp"}
147+
if path == "experiment-comparison2":
148+
return {"scores": {}, "metrics": {}}
149+
raise AssertionError(f"Unexpected get_json call: {path}, {args}")
150+
151+
mock_conn.get_json.side_effect = get_json
152+
153+
with patch.object(exp.state, "api_conn", return_value=mock_conn):
154+
summary = exp.summarize(comparison_experiment_id="base-exp-id")
155+
156+
assert summary.comparison_experiment_name == "base-exp"
157+
mock_conn.get_json.assert_any_call("v1/experiment/base-exp-id")
158+
mock_conn.get_json.assert_any_call(
159+
"experiment-comparison2",
160+
args={
161+
"experiment_id": "test-evaluator",
162+
"base_experiment_id": "base-exp-id",
163+
},
164+
)
165+
166+
81167
@pytest.mark.asyncio
82168
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
83169
async def test_run_evaluator_exposes_validated_parameter_values_to_hooks():

0 commit comments

Comments
 (0)