Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion py/src/braintrust/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,10 @@ async def run_evaluator(
)

if experiment:
summary = experiment.summarize(summarize_scores=evaluator.summarize_scores)
summary = experiment.summarize(
summarize_scores=evaluator.summarize_scores,
comparison_experiment_id=evaluator.base_experiment_id,
)
else:
summary = build_local_summary(evaluator, results)

Expand Down
6 changes: 6 additions & 0 deletions py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3969,6 +3969,12 @@ def summarize(
if base_experiment:
comparison_experiment_id = base_experiment.id
comparison_experiment_name = base_experiment.name
else:
try:
comparison_experiment = state.api_conn().get_json(f"v1/experiment/{comparison_experiment_id}")
comparison_experiment_name = comparison_experiment.get("name")
except Exception:
pass

try:
summary_items = state.api_conn().get_json(
Expand Down
58 changes: 57 additions & 1 deletion py/src/braintrust/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib.util
import re
import sys
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from braintrust.logger import BraintrustState
Expand Down Expand Up @@ -78,6 +78,62 @@ def exact_match(input_value, output, expected):
assert result.summary.scores["exact_match"].score == 1.0


@pytest.mark.asyncio
async def test_run_evaluator_forwards_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
def exact_match(input_value, output, expected):
return 1.0 if output == expected else 0.0

evaluator = Evaluator(
project_name="test-project",
eval_name="test-evaluator",
data=[EvalCase(input=1, expected=1)],
task=lambda input_value: input_value,
scores=[exact_match],
experiment_name=None,
metadata=None,
base_experiment_id="base-exp-id",
)

exp = init_test_exp("test-evaluator", "test-project")
expected_summary = MagicMock()
exp.summarize = MagicMock(return_value=expected_summary)

result = await run_evaluator(experiment=exp, evaluator=evaluator, position=None, filters=[])

assert result.summary is expected_summary
exp.summarize.assert_called_once_with(
summarize_scores=True,
comparison_experiment_id="base-exp-id",
)


def test_experiment_summarize_resolves_explicit_comparison_name(with_memory_logger, with_simulate_login):
exp = init_test_exp("test-evaluator", "test-project")
mock_conn = MagicMock()

def get_json(path, args=None):
if path == "v1/experiment/base-exp-id":
return {"name": "base-exp"}
if path == "experiment-comparison2":
return {"scores": {}, "metrics": {}}
raise AssertionError(f"Unexpected get_json call: {path}, {args}")

mock_conn.get_json.side_effect = get_json

with patch.object(exp.state, "api_conn", return_value=mock_conn):
summary = exp.summarize(comparison_experiment_id="base-exp-id")

assert summary.comparison_experiment_name == "base-exp"
mock_conn.get_json.assert_any_call("v1/experiment/base-exp-id")
mock_conn.get_json.assert_any_call(
"experiment-comparison2",
args={
"experiment_id": "test-evaluator",
"base_experiment_id": "base-exp-id",
},
)


@pytest.mark.asyncio
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
async def test_run_evaluator_exposes_validated_parameter_values_to_hooks():
Expand Down
Loading