Skip to content

Commit 37fd6d3

Browse files
committed
handle base_experiment_name comparison
1 parent 11b6bcd commit 37fd6d3

2 files changed

Lines changed: 42 additions & 1 deletion

File tree

py/src/braintrust/framework.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,14 @@ def _validate_classification_result(value: Any, classifier_name: str) -> Classif
13511351
return classification
13521352

13531353

1354+
def _get_persisted_base_experiment_id(experiment: Experiment) -> str | None:
1355+
try:
1356+
base_experiment_id = experiment.data.get("base_exp_id")
1357+
except Exception:
1358+
return None
1359+
return base_experiment_id if isinstance(base_experiment_id, str) and base_experiment_id else None
1360+
1361+
13541362
async def run_evaluator(
13551363
experiment: Experiment | None,
13561364
evaluator: Evaluator[Input, Output, Expected],
@@ -1367,9 +1375,12 @@ async def run_evaluator(
13671375
)
13681376

13691377
if experiment:
1378+
comparison_experiment_id = evaluator.base_experiment_id
1379+
if comparison_experiment_id is None:
1380+
comparison_experiment_id = _get_persisted_base_experiment_id(experiment)
13701381
summary = experiment.summarize(
13711382
summarize_scores=evaluator.summarize_scores,
1372-
comparison_experiment_id=evaluator.base_experiment_id,
1383+
comparison_experiment_id=comparison_experiment_id,
13731384
)
13741385
else:
13751386
summary = build_local_summary(evaluator, results)

py/src/braintrust/test_framework.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,36 @@ def exact_match(input_value, output, expected):
107107
)
108108

109109

110+
@pytest.mark.asyncio
111+
async def test_run_evaluator_forwards_persisted_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
112+
def exact_match(input_value, output, expected):
113+
return 1.0 if output == expected else 0.0
114+
115+
evaluator = Evaluator(
116+
project_name="test-project",
117+
eval_name="test-evaluator",
118+
data=[EvalCase(input=1, expected=1)],
119+
task=lambda input_value: input_value,
120+
scores=[exact_match],
121+
experiment_name=None,
122+
metadata=None,
123+
base_experiment_name="base-exp",
124+
)
125+
126+
exp = init_test_exp("test-evaluator", "test-project")
127+
exp.data["base_exp_id"] = "base-exp-id"
128+
expected_summary = MagicMock()
129+
exp.summarize = MagicMock(return_value=expected_summary)
130+
131+
result = await run_evaluator(experiment=exp, evaluator=evaluator, position=None, filters=[])
132+
133+
assert result.summary is expected_summary
134+
exp.summarize.assert_called_once_with(
135+
summarize_scores=True,
136+
comparison_experiment_id="base-exp-id",
137+
)
138+
139+
110140
def test_experiment_summarize_resolves_explicit_comparison_name(with_memory_logger, with_simulate_login):
111141
exp = init_test_exp("test-evaluator", "test-project")
112142
mock_conn = MagicMock()

0 commit comments

Comments
 (0)