Skip to content

Commit 11b6bcd

Browse files
committed
fix(experiments): pass base_experiment_id to summarize
1 parent 28fae42 commit 11b6bcd

3 files changed

Lines changed: 67 additions & 2 deletions

File tree

py/src/braintrust/framework.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1367,7 +1367,10 @@ async def run_evaluator(
13671367
)
13681368

13691369
if experiment:
1370-
summary = experiment.summarize(summarize_scores=evaluator.summarize_scores)
1370+
summary = experiment.summarize(
1371+
summarize_scores=evaluator.summarize_scores,
1372+
comparison_experiment_id=evaluator.base_experiment_id,
1373+
)
13711374
else:
13721375
summary = build_local_summary(evaluator, results)
13731376

py/src/braintrust/logger.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3969,6 +3969,12 @@ def summarize(
39693969
if base_experiment:
39703970
comparison_experiment_id = base_experiment.id
39713971
comparison_experiment_name = base_experiment.name
3972+
else:
3973+
try:
3974+
comparison_experiment = state.api_conn().get_json(f"v1/experiment/{comparison_experiment_id}")
3975+
comparison_experiment_name = comparison_experiment.get("name")
3976+
except Exception:
3977+
pass
39723978

39733979
try:
39743980
summary_items = state.api_conn().get_json(

py/src/braintrust/test_framework.py

Lines changed: 57 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import importlib.util
22
import re
33
import sys
4-
from unittest.mock import MagicMock
4+
from unittest.mock import MagicMock, patch
55

66
import pytest
77
from braintrust.logger import BraintrustState
@@ -78,6 +78,62 @@ def exact_match(input_value, output, expected):
7878
assert result.summary.scores["exact_match"].score == 1.0
7979

8080

81+
@pytest.mark.asyncio
async def test_run_evaluator_forwards_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
    """Check that run_evaluator passes the evaluator's base experiment id to summarize.

    The experiment's ``summarize`` is replaced by a mock, so the test only
    inspects the keyword arguments it receives and that its return value is
    surfaced unchanged as ``result.summary``.
    """

    def exact_match(input_value, output, expected):
        # Binary scorer: full credit on equality, none otherwise.
        if output == expected:
            return 1.0
        return 0.0

    evaluator = Evaluator(
        project_name="test-project",
        eval_name="test-evaluator",
        data=[EvalCase(input=1, expected=1)],
        task=lambda input_value: input_value,
        scores=[exact_match],
        experiment_name=None,
        metadata=None,
        base_experiment_id="base-exp-id",
    )

    exp = init_test_exp("test-evaluator", "test-project")
    # Stub out summarize so the call can be inspected without real summary work.
    stub_summary = MagicMock()
    summarize_mock = MagicMock(return_value=stub_summary)
    exp.summarize = summarize_mock

    result = await run_evaluator(experiment=exp, evaluator=evaluator, position=None, filters=[])

    # The stubbed summary must come back untouched ...
    assert result.summary is stub_summary
    # ... and summarize must have been called exactly once with the base experiment id.
    exp.summarize.assert_called_once_with(
        summarize_scores=True,
        comparison_experiment_id="base-exp-id",
    )
108+
109+
110+
def test_experiment_summarize_resolves_explicit_comparison_name(with_memory_logger, with_simulate_login):
    """Check that summarize looks up the name of an explicitly given comparison experiment.

    ``state.api_conn`` is patched to return a mock connection whose ``get_json``
    answers only the two paths this test expects; any other path fails the test.
    """
    exp = init_test_exp("test-evaluator", "test-project")

    # Path -> factory for the canned payload; a fresh dict is built per call.
    canned_responses = {
        "v1/experiment/base-exp-id": lambda: {"name": "base-exp"},
        "experiment-comparison2": lambda: {"scores": {}, "metrics": {}},
    }

    def get_json(path, args=None):
        build_payload = canned_responses.get(path)
        if build_payload is None:
            raise AssertionError(f"Unexpected get_json call: {path}, {args}")
        return build_payload()

    mock_conn = MagicMock()
    mock_conn.get_json.side_effect = get_json

    with patch.object(exp.state, "api_conn", return_value=mock_conn):
        summary = exp.summarize(comparison_experiment_id="base-exp-id")

    # The name comes from the mocked experiment lookup, not from the caller.
    assert summary.comparison_experiment_name == "base-exp"

    expected_comparison_args = {
        "experiment_id": "test-evaluator",
        "base_experiment_id": "base-exp-id",
    }
    mock_conn.get_json.assert_any_call("v1/experiment/base-exp-id")
    mock_conn.get_json.assert_any_call("experiment-comparison2", args=expected_comparison_args)
135+
136+
81137
@pytest.mark.asyncio
82138
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
83139
async def test_run_evaluator_exposes_validated_parameter_values_to_hooks():

0 commit comments

Comments
 (0)