diff --git a/src/modelbench/cli.py b/src/modelbench/cli.py index eefe08820..eafda381c 100644 --- a/src/modelbench/cli.py +++ b/src/modelbench/cli.py @@ -62,8 +62,9 @@ def decorator(func): @click.option( "--output-dir", "-o", - default="./run/records", + default="records", type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path), + help="Directory where benchmark records will be saved relative to the run directory", ) @click.option("--max-instances", "-m", type=int, default=None) @click.option("--debug", default=False, is_flag=True) @@ -122,14 +123,23 @@ def wrapper(*args, **kwargs): @click.group() @local_plugin_dir_option -def cli() -> None: +@click.option( + "--run-path", + default="./run", + type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path), +) +@click.pass_context +def cli(ctx: click.Context, run_path) -> None: + ctx.ensure_object(dict) + ctx.obj["run_path"] = run_path + PROMETHEUS.push_metrics() try: faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False) except io.UnsupportedOperation: pass # just an issue with some tests that capture sys.stderr - log_dir = pathlib.Path("run/logs") + log_dir = run_path / "logs" log_dir.mkdir(exist_ok=True, parents=True) filename = log_dir / f'modelbench-{datetime.now().strftime("%y%m%d-%H%M%S")}.log' logging.basicConfig(level=logging.DEBUG, handlers=[get_file_logging_handler(filename)], force=True) @@ -163,7 +173,9 @@ def list_suts(): multiple=False, ) @benchmark_options(GENERAL_PROMPT_SETS, "demo") +@click.pass_context def general_benchmark( + ctx: click.Context, version: str, output_dir: pathlib.Path, max_instances: int | None, @@ -176,15 +188,18 @@ def general_benchmark( prompt_set="demo", evaluator="default", ) -> None: + run_path: pathlib.Path = ctx.obj["run_path"] sut = make_sut(sut_uid) benchmark = GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, evaluator) check_benchmark(benchmark) - run_and_report_benchmark(benchmark, sut, max_instances, debug, 
json_logs, output_dir, run_uid, user) + run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user) @benchmark.command("security", help="run a security benchmark") @benchmark_options(SECURITY_JAILBREAK_PROMPT_SETS, "official") +@click.pass_context def security_benchmark( + ctx: click.Context, output_dir: pathlib.Path, max_instances: int | None, debug: bool, @@ -196,26 +211,28 @@ def security_benchmark( prompt_set="official", evaluator="default", ) -> None: + run_path: pathlib.Path = ctx.obj["run_path"] sut = make_sut(sut_uid) benchmark = SecurityBenchmark(locale, prompt_set, evaluator=evaluator) check_benchmark(benchmark) - run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid, user) + run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user) -def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid, user): +def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user): start_time = datetime.now(timezone.utc) - run = run_benchmarks_for_sut([benchmark], sut, max_instances, debug=debug, json_logs=json_logs) + run = run_benchmarks_for_sut([benchmark], sut, max_instances, run_path=run_path, debug=debug, json_logs=json_logs) benchmark_scores = score_benchmarks(run) - output_dir.mkdir(exist_ok=True, parents=True) + output_path = run_path / output_dir + output_path.mkdir(exist_ok=True, parents=True) print_summary(benchmark, benchmark_scores) - json_path = output_dir / f"benchmark_record-{benchmark.uid}.json" + json_path = output_path / f"benchmark_record-{benchmark.uid}.json" scores = [score for score in benchmark_scores if score.benchmark_definition == benchmark] dump_json(json_path, start_time, benchmark, scores, run_uid, user) print(f"Wrote record for {benchmark.uid} to {json_path}.") # export the annotations separately annotations = 
{"job_id": run.run_id, "annotations": run.compile_annotations()} - annotation_path = output_dir / f"annotations-{benchmark.uid}.json" + annotation_path = output_path / f"annotations-{benchmark.uid}.json" with open(annotation_path, "w") as annotation_records: annotation_records.write(json.dumps(annotations)) print(f"Wrote annotations for {benchmark.uid} to {annotation_path}.") diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py index 5d5dce957..0d56e7e4e 100644 --- a/tests/modelbench_tests/test_run.py +++ b/tests/modelbench_tests/test_run.py @@ -290,8 +290,19 @@ def do_print_summary(self, monkeypatch): monkeypatch.setattr(modelbench.cli, "print_summary", MagicMock()) @pytest.fixture - def runner(self): - return CliRunner() + def run_dir(self, tmp_path): + return tmp_path + + @pytest.fixture + def runner(self, run_dir): + runner = CliRunner() + + def invoke(command, args=None, **kwargs): + args = list(args or []) + full_args = ["--run-path", str(run_dir)] + args + return runner.invoke(command, full_args, **kwargs) + + return invoke @pytest.mark.parametrize( "version,locale,prompt_set", @@ -314,7 +325,7 @@ def test_benchmark_basic_run_produces_json( version, locale, prompt_set, - tmp_path, + run_dir, ): benchmark_options = ["--version", version] if locale is not None: @@ -333,22 +344,20 @@ def test_benchmark_basic_run_produces_json( "1", "--sut", sut_uid, - "--output-dir", - str(tmp_path.absolute()), *benchmark_options, ] - result = runner.invoke( + result = runner( cli, command_options, catch_exceptions=False, ) assert result.exit_code == 0 - assert (tmp_path / f"benchmark_record-{benchmark.uid}.json").exists() + assert (run_dir / "records" / f"benchmark_record-{benchmark.uid}.json").exists() - annotation_file_path = tmp_path / f"annotations-{benchmark.uid}.json" + annotation_file_path = run_dir / "records" / f"annotations-{benchmark.uid}.json" assert annotation_file_path.exists() # TODO find a better spot for this test. 
It's handy here because all the objects are available. - assert annotations_are_correct(tmp_path / f"annotations-{benchmark.uid}.json", prompt_set) + assert annotations_are_correct(annotation_file_path, prompt_set) # TODO: Add test back after calibrating!! # def test_security_benchmark_basic_run_produces_json( @@ -386,7 +395,7 @@ def test_benchmark_basic_run_produces_json( ) @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:nebius:hfrelay;mt=500;t=0.3"]) def test_benchmark_multiple_suts_produces_json( - self, mock_run_benchmarks, runner, version, locale, prompt_set, sut_uid, tmp_path, monkeypatch + self, mock_run_benchmarks, runner, version, locale, prompt_set, sut_uid, run_dir, monkeypatch ): benchmark_options = ["--version", version] @@ -403,7 +412,7 @@ def test_benchmark_multiple_suts_produces_json( mock = MagicMock(return_value=[self.mock_score(sut_uid, benchmark), self.mock_score("demo_yes_no", benchmark)]) monkeypatch.setattr(modelbench.cli, "score_benchmarks", mock) - result = runner.invoke( + result = runner( cli, [ "benchmark", @@ -414,22 +423,20 @@ def test_benchmark_multiple_suts_produces_json( sut_uid, "--sut", "demo_yes_no", - "--output-dir", - str(tmp_path.absolute()), *benchmark_options, ], catch_exceptions=False, ) assert result.exit_code == 0 - assert (tmp_path / f"benchmark_record-{benchmark.uid}.json").exists + assert (run_dir / "records" / f"benchmark_record-{benchmark.uid}.json").exists() - def test_benchmark_bad_sut_errors_out(self, runner, tmp_path): + def test_benchmark_bad_sut_errors_out(self, runner): benchmark_options = ["--version", "1.1"] benchmark_options.extend(["--locale", "en_us"]) benchmark_options.extend(["--prompt-set", "practice"]) with pytest.raises(ValueError, match="No registration for bogus"): - _ = runner.invoke( + _ = runner( cli, [ "benchmark", @@ -438,15 +445,13 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path): "1", "--sut", "bogus", - "--output-dir", - str(tmp_path.absolute()), 
*benchmark_options, ], catch_exceptions=False, ) with pytest.raises(UnknownSUTMakerError): - _ = runner.invoke( + _ = runner( cli, [ "benchmark", @@ -455,8 +460,6 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path): "1", "--sut", "google/gemma:cohere:bogus", - "--output-dir", - str(tmp_path.absolute()), *benchmark_options, ], catch_exceptions=False, @@ -467,7 +470,7 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path): side_effect=ProviderNotFoundError("bad provider"), ): with pytest.raises(ModelNotSupportedError): - _ = runner.invoke( + _ = runner( cli, [ "benchmark", @@ -476,8 +479,6 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path): "1", "--sut", "meta/llama:notreal:hfrelay", - "--output-dir", - str(tmp_path.absolute()), *benchmark_options, ], catch_exceptions=False, @@ -488,7 +489,7 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path): side_effect=ModelNotSupportedError("bad model"), ): with pytest.raises(ModelNotSupportedError): - _ = runner.invoke( + _ = runner( cli, [ "benchmark", @@ -497,8 +498,6 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path): "1", "--sut", "google/bogus:cohere:hfrelay", - "--output-dir", - str(tmp_path.absolute()), *benchmark_options, ], catch_exceptions=False, @@ -506,13 +505,13 @@ def test_benchmark_bad_sut_errors_out(self, runner, tmp_path): @pytest.mark.parametrize("version", ["0.0", "0.5"]) def test_invalid_benchmark_versions_can_not_be_called(self, version, runner): - result = runner.invoke(cli, ["benchmark", "general", "--version", "0.0"]) + result = runner(cli, ["benchmark", "general", "--version", version]) assert result.exit_code == 2 assert "Invalid value for '--version'" in result.output @pytest.mark.skip(reason="we have temporarily removed other languages") def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_run_benchmarks, sut_uid): - _ = runner.invoke(cli, ["benchmark", "general", "--locale", FR_FR, "--sut", sut_uid]) + _ = runner(cli, 
["benchmark", "general", "--locale", FR_FR, "--sut", sut_uid]) benchmark_arg = mock_run_benchmarks.call_args.args[0][0] assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmarkV1) @@ -520,13 +519,13 @@ def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_run_ben # TODO: Add back when we add new versions. # def test_calls_score_benchmark_with_correct_version(self, runner, mock_score_benchmarks): - # result = runner.invoke(cli, ["benchmark", "general", "--version", "0.5"]) + # result = runner(cli, ["benchmark", "general", "--version", "0.5"]) # # benchmark_arg = mock_score_benchmarks.call_args.args[0][0] # assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmark) @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:nebius:hfrelay"]) def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid): - _ = runner.invoke(cli, ["benchmark", "general", "--sut", sut_uid]) + _ = runner(cli, ["benchmark", "general", "--sut", sut_uid]) benchmark_arg = mock_run_benchmarks.call_args.args[0][0] assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmarkV1) @@ -535,20 +534,20 @@ def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid): @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:nebius:hfrelay"]) def test_nonexistent_benchmark_prompt_sets_can_not_be_called(self, runner, sut_uid): - result = runner.invoke(cli, ["benchmark", "general", "--prompt-set", "fake", "--sut", sut_uid]) + result = runner(cli, ["benchmark", "general", "--prompt-set", "fake", "--sut", sut_uid]) assert result.exit_code == 2 assert "Invalid value for '--prompt-set'" in result.output @pytest.mark.parametrize("prompt_set", GENERAL_PROMPT_SETS.keys()) @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:nebius:hfrelay"]) def test_calls_score_benchmark_with_correct_prompt_set(self, runner, mock_run_benchmarks, prompt_set, sut_uid): - _ = runner.invoke(cli, ["benchmark", "general", 
"--prompt-set", prompt_set, "--sut", sut_uid]) + _ = runner(cli, ["benchmark", "general", "--prompt-set", prompt_set, "--sut", sut_uid]) benchmark_arg = mock_run_benchmarks.call_args.args[0][0] assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmarkV1) assert benchmark_arg.prompt_set == prompt_set - def test_fails_to_run_uncalibrated_benchmark(self, runner, mock_score_benchmarks, tmp_path, standards_path_patch): + def test_fails_to_run_uncalibrated_benchmark(self, runner, mock_score_benchmarks, standards_path_patch): command_options = [ "benchmark", "general", @@ -556,13 +555,11 @@ def test_fails_to_run_uncalibrated_benchmark(self, runner, mock_score_benchmarks "1", "--sut", "fake-sut", - "--output-dir", - str(tmp_path.absolute()), "--locale", "fr_FR", ] with pytest.raises(NoStandardsFileError) as e: - runner.invoke( + runner( cli, command_options, catch_exceptions=False, @@ -618,7 +615,7 @@ def test_calibrate( *benchmark_options, ] - result = runner.invoke( + result = runner( cli, command_options, catch_exceptions=False, @@ -671,7 +668,7 @@ def test_calibrate_security( *benchmark_options, ] - result = runner.invoke( + result = runner( cli, command_options, catch_exceptions=True, @@ -695,7 +692,7 @@ def test_fails_to_calibrate_benchmark_with_standards(self, runner): "default", ] with pytest.raises(OverwriteStandardsFileError) as e: - runner.invoke( + runner( cli, command_options, catch_exceptions=False, diff --git a/tests/modelgauge_tests/test_cli.py b/tests/modelgauge_tests/test_cli.py index 7d91286de..67de8c497 100644 --- a/tests/modelgauge_tests/test_cli.py +++ b/tests/modelgauge_tests/test_cli.py @@ -123,10 +123,23 @@ def test_run_annotator_demo(): @pytest.mark.parametrize("test", ["demo_01", "demo_02", "demo_03"]) -def test_run_test_demos(sut_uid, test): - result = run_cli("run-test", "--test", test, "--sut", sut_uid, "--max-test-items", "1") +def test_run_test_demos(sut_uid, test, tmp_path): + tmp_output_file = tmp_path / "output.json" + result = 
run_cli( + "run-test", + "--test", + test, + "--sut", + sut_uid, + "--max-test-items", + "1", + "--data-dir", + str(tmp_path), + "--output-file", + str(tmp_output_file), + ) assert result.exit_code == 0 - assert re.search(r"Full TestRecord json written to output", result.output) + assert re.search(rf"Full TestRecord json written to {re.escape(str(tmp_output_file))}", result.output) def test_run_test_invalid_sut_uid():