Skip to content

Commit c7369a6

Browse files
authored
Allow run_uid (#1405)
CLI accepts --run-uid as an argument and will use that for the run_uid of the benchmark run instead of an automatically generated run_uid.
1 parent 77bda0f commit c7369a6

3 files changed

Lines changed: 37 additions & 14 deletions

File tree

src/modelbench/cli.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616
from rich.table import Table
1717

1818
import modelgauge.annotators.cheval.registration # noqa: F401
19+
from modelbench.benchmark_runner import BenchmarkRunner, JsonRunTracker, TqdmRunTracker
20+
from modelbench.benchmarks import GeneralPurposeAiChatBenchmarkV1, SecurityBenchmark
21+
from modelbench.consistency_checker import (
22+
ConsistencyChecker,
23+
summarize_consistency_check_results,
24+
)
25+
from modelbench.record import dump_json
26+
from modelbench.standards import Standards
1927
from modelgauge.config import load_secrets_from_config, write_default_config
2028
from modelgauge.load_namespaces import load_namespaces
2129
from modelgauge.locales import DEFAULT_LOCALE, LOCALES
@@ -25,12 +33,6 @@
2533
from modelgauge.prompt_sets import GENERAL_PROMPT_SETS, SECURITY_JAILBREAK_PROMPT_SETS
2634
from modelgauge.sut_registry import SUTS
2735

28-
from modelbench.benchmark_runner import BenchmarkRunner, JsonRunTracker, TqdmRunTracker
29-
from modelbench.benchmarks import GeneralPurposeAiChatBenchmarkV1, SecurityBenchmark
30-
from modelbench.standards import Standards
31-
from modelbench.consistency_checker import ConsistencyChecker, summarize_consistency_check_results
32-
from modelbench.record import dump_json
33-
3436

3537
def load_local_plugins(_, __, path: pathlib.Path):
3638
path_str = str(path)
@@ -90,6 +92,12 @@ def decorator(func):
9092
help="Which evaluator to use",
9193
show_default=True,
9294
)
95+
@click.option(
96+
"--run-uid",
97+
type=str,
98+
required=False,
99+
help="The run_uid for the run if provided, otherwise one will be generated",
100+
)
93101
@local_plugin_dir_option
94102
@wraps(func)
95103
def wrapper(*args, **kwargs):
@@ -151,13 +159,14 @@ def general_benchmark(
151159
json_logs: bool,
152160
sut_uid: str,
153161
locale: str,
162+
run_uid: str,
154163
prompt_set="demo",
155164
evaluator="default",
156165
) -> None:
157166
sut = make_sut(sut_uid)
158167
benchmark = GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, evaluator)
159168
check_benchmark(benchmark)
160-
run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir)
169+
run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid)
161170

162171

163172
@benchmark.command("security", help="run a security benchmark")
@@ -169,17 +178,18 @@ def security_benchmark(
169178
json_logs: bool,
170179
sut_uid: str,
171180
locale: str,
181+
run_uid: str,
172182
prompt_set="official",
173183
evaluator="default",
174184
) -> None:
175185
sut = make_sut(sut_uid)
176186
benchmark = SecurityBenchmark(locale, prompt_set, evaluator=evaluator)
177187
check_benchmark(benchmark)
178188

179-
run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir)
189+
run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid)
180190

181191

182-
def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir):
192+
def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid):
183193
start_time = datetime.now(timezone.utc)
184194
run = run_benchmarks_for_sut([benchmark], sut, max_instances, debug=debug, json_logs=json_logs)
185195

@@ -188,7 +198,7 @@ def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, ou
188198
print_summary(benchmark, benchmark_scores)
189199
json_path = output_dir / f"benchmark_record-{benchmark.uid}.json"
190200
scores = [score for score in benchmark_scores if score.benchmark_definition == benchmark]
191-
dump_json(json_path, start_time, benchmark, scores)
201+
dump_json(json_path, start_time, benchmark, scores, run_uid)
192202
print(f"Wrote record for {benchmark.uid} to {json_path}.")
193203
run_consistency_check(run.journal_path, verbose=True)
194204

src/modelbench/record.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Sequence
88

99
import pydantic
10+
1011
from modelbench.benchmarks import BaseBenchmarkScore, BenchmarkDefinition
1112
from modelbench.hazards import HazardDefinition, HazardScore
1213
from modelgauge.base_test import BaseTest
@@ -76,12 +77,14 @@ def dump_json(
7677
start_time: datetime.time,
7778
benchmark: BenchmarkDefinition,
7879
benchmark_scores: Sequence[BaseBenchmarkScore],
80+
run_uid: str | None,
7981
):
82+
_run_uid = run_uid if run_uid else f"run-{benchmark.uid}-{start_time.strftime('%Y%m%d-%H%M%S')}"
8083
with open(json_path, "w") as f:
8184
output = {
8285
"_metadata": benchmark_metadata(),
8386
"benchmark": (benchmark),
84-
"run_uid": f"run-{benchmark.uid}-{start_time.strftime('%Y%m%d-%H%M%S')}",
87+
"run_uid": _run_uid,
8588
"scores": (benchmark_scores),
8689
}
8790
json.dump(output, f, cls=BenchmarkScoreEncoder, indent=4)

tests/modelbench_tests/test_record.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
import pytest
1010

11-
from modelbench.benchmarks import BenchmarkScore, GeneralPurposeAiChatBenchmarkV1, SecurityBenchmark, SecurityScore
11+
from modelbench.benchmarks import (
12+
BenchmarkScore,
13+
GeneralPurposeAiChatBenchmarkV1,
14+
SecurityBenchmark,
15+
SecurityScore,
16+
)
1217
from modelbench.hazards import HazardScore, SafeHazardV1, SecurityJailbreakHazard
1318
from modelbench.record import BenchmarkScoreEncoder, benchmark_code_info, dump_json
1419
from modelbench.scoring import ValueEstimate
@@ -282,7 +287,8 @@ def test_benchmark_code_record_without_git(benchmark_score):
282287
assert source["error"] == "git command not found"
283288

284289

285-
def test_dump_json(benchmark_score, tmp_path):
290+
@pytest.mark.parametrize("run_uid", [None, "custom_run_uid"])
291+
def test_dump_json(benchmark_score, tmp_path, run_uid):
286292
# just a smoke test; everything substantial should be tested above.
287293
json_path = tmp_path / "foo.json"
288294
with mock.patch("modelbench.record.benchmark_library_info", lambda: {"skipped by": "test_run.fast_metadata"}):
@@ -291,11 +297,15 @@ def test_dump_json(benchmark_score, tmp_path):
291297
datetime.fromtimestamp(1700000000, timezone.utc),
292298
benchmark_score.benchmark_definition,
293299
[benchmark_score],
300+
run_uid,
294301
)
295302

296303
with open(json_path) as f:
297304
j = json.load(f)
298305
assert "_metadata" in j
299306
assert j["benchmark"]["uid"] == benchmark_score.benchmark_definition.uid
300-
assert j["run_uid"] == "run-" + benchmark_score.benchmark_definition.uid + "-20231114-221320"
307+
if not run_uid:
308+
assert j["run_uid"] == "run-" + benchmark_score.benchmark_definition.uid + "-20231114-221320"
309+
else:
310+
assert j["run_uid"] == run_uid
301311
assert len(j["scores"]) == 1

0 commit comments

Comments
 (0)