Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions src/modelbench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def decorator(func):
@click.option(
"--output-dir",
"-o",
default="./run/records",
default="records",
type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
help="Directory where benchmark records will be saved relative to the run directory",
)
@click.option("--max-instances", "-m", type=int, default=None)
@click.option("--debug", default=False, is_flag=True)
Expand Down Expand Up @@ -122,14 +123,23 @@ def wrapper(*args, **kwargs):

@click.group()
@local_plugin_dir_option
def cli() -> None:
@click.option(
"--run-path",
default="./run",
type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
)
@click.pass_context
def cli(ctx: click.Context, run_path) -> None:
ctx.ensure_object(dict)
ctx.obj["run_path"] = run_path

PROMETHEUS.push_metrics()
try:
faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False)
except io.UnsupportedOperation:
pass # just an issue with some tests that capture sys.stderr

log_dir = pathlib.Path("run/logs")
log_dir = run_path / "logs"
log_dir.mkdir(exist_ok=True, parents=True)
filename = log_dir / f'modelbench-{datetime.now().strftime("%y%m%d-%H%M%S")}.log'
logging.basicConfig(level=logging.DEBUG, handlers=[get_file_logging_handler(filename)], force=True)
Expand Down Expand Up @@ -163,7 +173,9 @@ def list_suts():
multiple=False,
)
@benchmark_options(GENERAL_PROMPT_SETS, "demo")
@click.pass_context
def general_benchmark(
ctx: click.Context,
version: str,
output_dir: pathlib.Path,
max_instances: int | None,
Expand All @@ -176,15 +188,18 @@ def general_benchmark(
prompt_set="demo",
evaluator="default",
) -> None:
run_path: pathlib.Path = ctx.obj["run_path"]
sut = make_sut(sut_uid)
benchmark = GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, evaluator)
check_benchmark(benchmark)
run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid, user)
run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user)


@benchmark.command("security", help="run a security benchmark")
@benchmark_options(SECURITY_JAILBREAK_PROMPT_SETS, "official")
@click.pass_context
def security_benchmark(
ctx: click.Context,
output_dir: pathlib.Path,
max_instances: int | None,
debug: bool,
Expand All @@ -196,26 +211,28 @@ def security_benchmark(
prompt_set="official",
evaluator="default",
) -> None:
run_path: pathlib.Path = ctx.obj["run_path"]
sut = make_sut(sut_uid)
benchmark = SecurityBenchmark(locale, prompt_set, evaluator=evaluator)
check_benchmark(benchmark)
run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid, user)
run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user)


def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid, user):
def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, outputdir, run_uid, user):
start_time = datetime.now(timezone.utc)
run = run_benchmarks_for_sut([benchmark], sut, max_instances, debug=debug, json_logs=json_logs)
run = run_benchmarks_for_sut([benchmark], sut, max_instances, run_path=run_path, debug=debug, json_logs=json_logs)
benchmark_scores = score_benchmarks(run)
output_dir.mkdir(exist_ok=True, parents=True)
output_path = run_path / outputdir
output_path.mkdir(exist_ok=True, parents=True)
print_summary(benchmark, benchmark_scores)
json_path = output_dir / f"benchmark_record-{benchmark.uid}.json"
json_path = output_path / f"benchmark_record-{benchmark.uid}.json"
scores = [score for score in benchmark_scores if score.benchmark_definition == benchmark]
dump_json(json_path, start_time, benchmark, scores, run_uid, user)
print(f"Wrote record for {benchmark.uid} to {json_path}.")

# export the annotations separately
annotations = {"job_id": run.run_id, "annotations": run.compile_annotations()}
annotation_path = output_dir / f"annotations-{benchmark.uid}.json"
annotation_path = output_path / f"annotations-{benchmark.uid}.json"
with open(annotation_path, "w") as annotation_records:
annotation_records.write(json.dumps(annotations))
print(f"Wrote annotations for {benchmark.uid} to {annotation_path}.")
Expand Down
77 changes: 37 additions & 40 deletions tests/modelbench_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,19 @@ def do_print_summary(self, monkeypatch):
monkeypatch.setattr(modelbench.cli, "print_summary", MagicMock())

@pytest.fixture
def run_dir(self, tmp_path):
    """Expose pytest's per-test temporary directory under the name ``run_dir``."""
    run_directory = tmp_path
    return run_directory

@pytest.fixture
def runner(self, run_dir):
    """Return a callable that invokes a CLI command with ``--run-path`` preset.

    The returned ``invoke(command, args=None, **kwargs)`` prepends
    ``--run-path <run_dir>`` to ``args`` and forwards everything else to
    ``CliRunner.invoke``, returning its result unchanged.
    """
    import shlex

    cli_runner = CliRunner()

    def invoke(command, args=None, **kwargs):
        if isinstance(args, str):
            # A bare string would be exploded into single characters by
            # list(); split it shell-style instead so string arguments keep
            # working the way they do with CliRunner.invoke itself.
            args = shlex.split(args)
        # Pass the path as text so the argument list contains only strings,
        # just like a real command line.
        full_args = ["--run-path", str(run_dir), *(args or [])]
        return cli_runner.invoke(command, full_args, **kwargs)

    return invoke

@pytest.mark.parametrize(
"version,locale,prompt_set",
Expand All @@ -314,7 +325,7 @@ def test_benchmark_basic_run_produces_json(
version,
locale,
prompt_set,
tmp_path,
run_dir,
):
benchmark_options = ["--version", version]
if locale is not None:
Expand All @@ -333,22 +344,20 @@ def test_benchmark_basic_run_produces_json(
"1",
"--sut",
sut_uid,
"--output-dir",
str(tmp_path.absolute()),
*benchmark_options,
]
result = runner.invoke(
result = runner(
cli,
command_options,
catch_exceptions=False,
)
assert result.exit_code == 0
assert (tmp_path / f"benchmark_record-{benchmark.uid}.json").exists()
assert (run_dir / "records" / f"benchmark_record-{benchmark.uid}.json").exists()

annotation_file_path = tmp_path / f"annotations-{benchmark.uid}.json"
annotation_file_path = run_dir / "records" / f"annotations-{benchmark.uid}.json"
assert annotation_file_path.exists()
# TODO find a better spot for this test. It's handy here because all the objects are available.
assert annotations_are_correct(tmp_path / f"annotations-{benchmark.uid}.json", prompt_set)
assert annotations_are_correct(annotation_file_path, prompt_set)

# TODO: Add test back after calibrating!!
# def test_security_benchmark_basic_run_produces_json(
Expand Down Expand Up @@ -386,7 +395,7 @@ def test_benchmark_basic_run_produces_json(
)
@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:nebius:hfrelay;mt=500;t=0.3"])
def test_benchmark_multiple_suts_produces_json(
self, mock_run_benchmarks, runner, version, locale, prompt_set, sut_uid, tmp_path, monkeypatch
self, mock_run_benchmarks, runner, version, locale, prompt_set, sut_uid, run_dir, monkeypatch
):

benchmark_options = ["--version", version]
Expand All @@ -403,7 +412,7 @@ def test_benchmark_multiple_suts_produces_json(
mock = MagicMock(return_value=[self.mock_score(sut_uid, benchmark), self.mock_score("demo_yes_no", benchmark)])
monkeypatch.setattr(modelbench.cli, "score_benchmarks", mock)

result = runner.invoke(
result = runner(
cli,
[
"benchmark",
Expand All @@ -414,22 +423,20 @@ def test_benchmark_multiple_suts_produces_json(
sut_uid,
"--sut",
"demo_yes_no",
"--output-dir",
str(tmp_path.absolute()),
*benchmark_options,
],
catch_exceptions=False,
)
assert result.exit_code == 0
assert (tmp_path / f"benchmark_record-{benchmark.uid}.json").exists
assert (run_dir / "records" / f"benchmark_record-{benchmark.uid}.json").exists

def test_benchmark_bad_sut_errors_out(self, runner, tmp_path):
def test_benchmark_bad_sut_errors_out(self, runner):
benchmark_options = ["--version", "1.1"]
benchmark_options.extend(["--locale", "en_us"])
benchmark_options.extend(["--prompt-set", "practice"])

with pytest.raises(ValueError, match="No registration for bogus"):
_ = runner.invoke(
_ = runner(
cli,
[
"benchmark",
Expand All @@ -438,15 +445,13 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path):
"1",
"--sut",
"bogus",
"--output-dir",
str(tmp_path.absolute()),
*benchmark_options,
],
catch_exceptions=False,
)

with pytest.raises(UnknownSUTMakerError):
_ = runner.invoke(
_ = runner(
cli,
[
"benchmark",
Expand All @@ -455,8 +460,6 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path):
"1",
"--sut",
"google/gemma:cohere:bogus",
"--output-dir",
str(tmp_path.absolute()),
*benchmark_options,
],
catch_exceptions=False,
Expand All @@ -467,7 +470,7 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path):
side_effect=ProviderNotFoundError("bad provider"),
):
with pytest.raises(ModelNotSupportedError):
_ = runner.invoke(
_ = runner(
cli,
[
"benchmark",
Expand All @@ -476,8 +479,6 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path):
"1",
"--sut",
"meta/llama:notreal:hfrelay",
"--output-dir",
str(tmp_path.absolute()),
*benchmark_options,
],
catch_exceptions=False,
Expand All @@ -488,7 +489,7 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path):
side_effect=ModelNotSupportedError("bad model"),
):
with pytest.raises(ModelNotSupportedError):
_ = runner.invoke(
_ = runner(
cli,
[
"benchmark",
Expand All @@ -497,36 +498,34 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path):
"1",
"--sut",
"google/bogus:cohere:hfrelay",
"--output-dir",
str(tmp_path.absolute()),
*benchmark_options,
],
catch_exceptions=False,
)

@pytest.mark.parametrize("version", ["0.0", "0.5"])
def test_invalid_benchmark_versions_can_not_be_called(self, version, runner):
    """Each unsupported ``--version`` value must be rejected with a usage error."""
    # Use the parametrized value: hard-coding "0.0" here would run the same
    # case twice and never exercise "0.5".
    result = runner(cli, ["benchmark", "general", "--version", version])
    assert result.exit_code == 2
    assert "Invalid value for '--version'" in result.output

@pytest.mark.skip(reason="we have temporarily removed other languages")
def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_run_benchmarks, sut_uid):
    """The ``--locale`` option must reach the benchmark handed to the run function."""
    cli_args = ["benchmark", "general", "--locale", FR_FR, "--sut", sut_uid]
    runner(cli, cli_args)

    # First positional argument of the mocked call is the list of benchmarks.
    benchmarks_passed = mock_run_benchmarks.call_args.args[0]
    requested = benchmarks_passed[0]
    assert isinstance(requested, GeneralPurposeAiChatBenchmarkV1)
    assert requested.locale == FR_FR

# TODO: Add back when we add new versions.
# def test_calls_score_benchmark_with_correct_version(self, runner, mock_score_benchmarks):
# result = runner.invoke(cli, ["benchmark", "general", "--version", "0.5"])
# result = runner(cli, ["benchmark", "general", "--version", "0.5"])
#
# benchmark_arg = mock_score_benchmarks.call_args.args[0][0]
# assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmark)
@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:nebius:hfrelay"])
def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid):
_ = runner.invoke(cli, ["benchmark", "general", "--sut", sut_uid])
_ = runner(cli, ["benchmark", "general", "--sut", sut_uid])

benchmark_arg = mock_run_benchmarks.call_args.args[0][0]
assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmarkV1)
Expand All @@ -535,34 +534,32 @@ def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid):

@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:nebius:hfrelay"])
def test_nonexistent_benchmark_prompt_sets_can_not_be_called(self, runner, sut_uid):
    """An unknown ``--prompt-set`` value must be rejected with a usage error."""
    cli_args = ["benchmark", "general", "--prompt-set", "fake", "--sut", sut_uid]
    outcome = runner(cli, cli_args)
    assert outcome.exit_code == 2
    assert "Invalid value for '--prompt-set'" in outcome.output

@pytest.mark.parametrize("prompt_set", GENERAL_PROMPT_SETS.keys())
@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:nebius:hfrelay"])
def test_calls_score_benchmark_with_correct_prompt_set(self, runner, mock_run_benchmarks, prompt_set, sut_uid):
    """The ``--prompt-set`` option must reach the benchmark handed to the run function."""
    cli_args = ["benchmark", "general", "--prompt-set", prompt_set, "--sut", sut_uid]
    runner(cli, cli_args)

    # First positional argument of the mocked call is the list of benchmarks.
    benchmarks_passed = mock_run_benchmarks.call_args.args[0]
    requested = benchmarks_passed[0]
    assert isinstance(requested, GeneralPurposeAiChatBenchmarkV1)
    assert requested.prompt_set == prompt_set

def test_fails_to_run_uncalibrated_benchmark(self, runner, mock_score_benchmarks, tmp_path, standards_path_patch):
def test_fails_to_run_uncalibrated_benchmark(self, runner, mock_score_benchmarks, standards_path_patch):
command_options = [
"benchmark",
"general",
"-m",
"1",
"--sut",
"fake-sut",
"--output-dir",
str(tmp_path.absolute()),
"--locale",
"fr_FR",
]
with pytest.raises(NoStandardsFileError) as e:
runner.invoke(
runner(
cli,
command_options,
catch_exceptions=False,
Expand Down Expand Up @@ -618,7 +615,7 @@ def test_calibrate(
*benchmark_options,
]

result = runner.invoke(
result = runner(
cli,
command_options,
catch_exceptions=False,
Expand Down Expand Up @@ -671,7 +668,7 @@ def test_calibrate_security(
*benchmark_options,
]

result = runner.invoke(
result = runner(
cli,
command_options,
catch_exceptions=True,
Expand All @@ -695,7 +692,7 @@ def test_fails_to_calibrate_benchmark_with_standards(self, runner):
"default",
]
with pytest.raises(OverwriteStandardsFileError) as e:
runner.invoke(
runner(
cli,
command_options,
catch_exceptions=False,
Expand Down
19 changes: 16 additions & 3 deletions tests/modelgauge_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,23 @@ def test_run_annotator_demo():


@pytest.mark.parametrize("test", ["demo_01", "demo_02", "demo_03"])
def test_run_test_demos(sut_uid, test, tmp_path):
    """Run each demo test through the CLI and check the reported record path.

    Data and output are redirected into ``tmp_path`` so the test leaves
    nothing behind in the working directory.
    """
    tmp_output_file = tmp_path / "output.json"
    result = run_cli(
        "run-test",
        "--test",
        test,
        "--sut",
        sut_uid,
        "--max-test-items",
        "1",
        "--data-dir",
        str(tmp_path),
        "--output-file",
        str(tmp_output_file),
    )
    assert result.exit_code == 0
    # Escape the message before searching: a filesystem path can contain regex
    # metacharacters (".", "+", Windows backslashes) that would otherwise be
    # interpreted as pattern syntax and mis-match or raise re.error.
    expected_message = f"Full TestRecord json written to {tmp_output_file}"
    assert re.search(re.escape(expected_message), result.output)


def test_run_test_invalid_sut_uid():
Expand Down