Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 72 additions & 86 deletions src/modelbench/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore
from modelbench.cache import DiskCache, MBCache
from modelbench.run_journal import RunJournal

from modelgauge.annotator import CompletionAnnotator
from modelgauge.annotator_registry import ANNOTATORS
from modelgauge.base_test import PromptResponseTest, TestResult
Expand All @@ -29,6 +30,7 @@
from modelgauge.single_turn_prompt_response import TestItem
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse


logger = logging.getLogger(__name__)
FINISHED_ITEMS = PROMETHEUS.gauge("mm_finished_items", "Finished items")
CACHED_SUT_RESPONSES = PROMETHEUS.counter("mm_cached_sut_responses", "Cached SUT responses")
Expand Down Expand Up @@ -115,7 +117,7 @@ def __init__(self, runner: "TestRunnerBase"):
self.data_dir = runner.data_dir
self.test_data_path = self.data_dir / "tests"
self.secrets = runner.secrets
self.suts = runner.suts
self.sut = runner.sut
self.max_items = runner.max_items
self.tests = []
self.test_annotators = {}
Expand Down Expand Up @@ -286,20 +288,17 @@ def __init__(self, test_run: TestRunBase):
self.test_run = test_run

def handle_item(self, item: TestRunItem):
for sut in self.test_run.suts:
run_item = TestRunItem(item.test, item.test_item, sut)
self.test_run.journal.item_entry(
"queuing item",
run_item,
prompt_text=item.test_item.prompt.text,
)

self.downstream_put(run_item)
run_item = TestRunItem(item.test, item.test_item, self.test_run.sut)
self.test_run.journal.item_entry(
"queuing item",
run_item,
prompt_text=item.test_item.prompt.text,
)
self.downstream_put(run_item)


class TestRunSutWorker(IntermediateCachingPipe):
def __init__(self, test_run: TestRunBase, cache: MBCache, thread_count=1):

super().__init__(cache, thread_count)
self.test_run = test_run

Expand All @@ -314,7 +313,6 @@ def handle_item(self, item: TestRunItem):
raw_response = self.cache[cache_key]
self.test_run.journal.item_entry("using cached sut response", item, response=raw_response)
CACHED_SUT_RESPONSES.inc()

else:
self._debug(f"cache entry not found; processing and saving")
with Timer() as timer:
Expand Down Expand Up @@ -468,50 +466,38 @@ def __init__(self, data_dir: pathlib.Path):
self.debug = False
self.data_dir = data_dir
self.secrets = None
self.suts = []
self.sut = None
self.max_items = 10
self.thread_count = 1
self.run_tracker = NullRunTracker()

def add_sut(self, sut: PromptResponseSUT):
self.suts.append(sut)

def _check_ready_to_run(self):
if not self.secrets:
raise ValueError("must set secrets")

if not self.suts:
raise ValueError("must call add_sut() at least once")
if not self.sut:
raise ValueError("must specify a sut")

def _check_external_services(self, run: TestRunBase):
self._check_suts_working(run)
self._check_sut_working(run)
self._check_annotators_working(run)

def _check_suts_working(self, run: TestRunBase):
def check_sut(sut: PromptResponseSUT):
try:
raw_request = sut.translate_text_prompt(
TextPrompt(text="Why did the chicken cross the road?"), options=SUTOptions(max_tokens=1000)
)
raw_response = sut.evaluate(raw_request)
response: SUTResponse = sut.translate_response(raw_request, raw_response)
if response.text:
return True
else:
raise ValueError(f"initial check failed with no text: {raw_response}")
except Exception as e:
logger.error(f"initial check failure for {sut.uid}", exc_info=e)
print(f"initial check failure for {sut.uid}")
traceback.print_exc()

return False

with ThreadPool(len(run.suts)) as pool:
suts_worked = pool.map(check_sut, self.suts)
if not all(suts_worked):
raise RuntimeError(
f"Not all SUTs are ready to go. Status: {dict(zip([s.uid for s in self.suts], suts_worked))}"
)
def _check_sut_working(self, run: TestRunBase):
try:
raw_request = self.sut.translate_text_prompt(
TextPrompt(text="Why did the chicken cross the road?"), options=SUTOptions(max_tokens=1000)
)
raw_response = self.sut.evaluate(raw_request)
response: SUTResponse = self.sut.translate_response(raw_request, raw_response)
if response.text:
return True
else:
raise RuntimeError(f"initial check failed with no text: {raw_response}")
except Exception as e:
logger.error(f"initial check failure for {self.sut.uid}", exc_info=e)
print(f"initial check failure for {self.sut.uid}")
traceback.print_exc()
raise RuntimeError("SUT is not ready to go.")

def _check_annotators_working(self, run: TestRunBase):
def check_annotator(annotator: CompletionAnnotator):
Expand Down Expand Up @@ -539,15 +525,15 @@ def check_annotator(annotator: CompletionAnnotator):
)

def _calculate_test_results(self, test_run):
for sut in test_run.suts:
for test in test_run.tests:
finished_items = test_run.finished_items_for(sut, test)
test_result = test.aggregate_measurements(finished_items)
test_record = self._make_test_record(test_run, sut, test, test_result)
test_run.add_test_record(test_record)
test_run.journal.raw_entry(
"test scored", sut=sut.uid, test=test.uid, items_finished=len(finished_items), result=test_result
)
sut = test_run.sut
for test in test_run.tests:
finished_items = test_run.finished_items_for(sut, test)
test_result = test.aggregate_measurements(finished_items)
test_record = self._make_test_record(test_run, sut, test, test_result)
test_run.add_test_record(test_record)
test_run.journal.raw_entry(
"test scored", sut=sut.uid, test=test.uid, items_finished=len(finished_items), result=test_result
)

def _make_test_record(self, run, sut, test, test_result):
return TestRecord(
Expand Down Expand Up @@ -577,7 +563,7 @@ def _build_pipeline(self, run):
return pipeline

def _expected_item_count(self, the_run: TestRunBase, pipeline: Pipeline):
return len(the_run.suts) * len(list(pipeline.source.new_item_iterable(quiet=True)))
return len(list(pipeline.source.new_item_iterable(quiet=True)))


class TestRunner(TestRunnerBase):
Expand Down Expand Up @@ -629,7 +615,7 @@ def run(self) -> BenchmarkRun:
run_id=benchmark_run.run_id,
benchmarks=[b.uid for b in benchmark_run.benchmarks],
tests=[t.uid for t in benchmark_run.tests],
suts=[s.uid for s in benchmark_run.suts],
sut=[benchmark_run.sut.uid], # type: ignore
max_items=benchmark_run.max_items,
thread_count=self.thread_count,
)
Expand Down Expand Up @@ -687,39 +673,39 @@ def run(self) -> BenchmarkRun:
return benchmark_run

def _calculate_benchmark_scores(self, benchmark_run):
sut = benchmark_run.sut
for benchmark_definition in benchmark_run.benchmarks:
for sut in benchmark_run.suts:
hazard_scores = []
for hazard in benchmark_definition.hazards():
test_records = {}
for test in hazard.tests(benchmark_run.secrets):
records = benchmark_run.test_records[test.uid][sut.uid]
assert records, f"No records found for {benchmark_definition} {sut.uid} {hazard} {test.uid}"
test_records[test.uid] = records

assert test_records, f"No records found for {benchmark_definition} {sut.uid} {hazard}"

hazard_score = hazard.score(test_records)
hazard_scores.append(hazard_score) # TODO: score needs way less
benchmark_run.journal.raw_entry(
"hazard scored",
benchmark=benchmark_definition.uid,
sut=sut.uid,
hazard=hazard.uid,
score=hazard_score.score.estimate,
reference=hazard.reference_standard(),
samples=hazard_score.score.samples,
numeric_grade=hazard_score.numeric_grade(),
text_grade=hazard_score.text_grade(),
)

benchmark_score = BenchmarkScore(benchmark_definition, sut, hazard_scores, end_time=datetime.now())
benchmark_run.benchmark_scores[benchmark_definition][sut] = benchmark_score
hazard_scores = []
for hazard in benchmark_definition.hazards():
test_records = {}
for test in hazard.tests(benchmark_run.secrets):
records = benchmark_run.test_records[test.uid][sut.uid]
assert records, f"No records found for {benchmark_definition} {sut.uid} {hazard} {test.uid}"
test_records[test.uid] = records

assert test_records, f"No records found for {benchmark_definition} {sut.uid} {hazard}"

hazard_score = hazard.score(test_records)
hazard_scores.append(hazard_score) # TODO: score needs way less
benchmark_run.journal.raw_entry(
"benchmark scored",
"hazard scored",
benchmark=benchmark_definition.uid,
sut=sut.uid,
numeric_grade=benchmark_score.numeric_grade(),
text_grade=benchmark_score.text_grade(),
scoring_log=benchmark_score._scoring_log,
hazard=hazard.uid,
score=hazard_score.score.estimate,
reference=hazard.reference_standard(),
samples=hazard_score.score.samples,
numeric_grade=hazard_score.numeric_grade(),
text_grade=hazard_score.text_grade(),
)

benchmark_score = BenchmarkScore(benchmark_definition, sut, hazard_scores, end_time=datetime.now())
benchmark_run.benchmark_scores[benchmark_definition][sut] = benchmark_score
benchmark_run.journal.raw_entry(
"benchmark scored",
benchmark=benchmark_definition.uid,
sut=sut.uid,
numeric_grade=benchmark_score.numeric_grade(),
text_grade=benchmark_score.text_grade(),
scoring_log=benchmark_score._scoring_log,
)
4 changes: 2 additions & 2 deletions src/modelbench/consistency_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from typing import Dict, List

import casefy
from modelgauge.config import load_secrets_from_config
from modelgauge.test_registry import TESTS
from rich.console import Console
from rich.table import Table

from modelbench.run_journal import journal_reader
from modelgauge.config import load_secrets_from_config
from modelgauge.test_registry import TESTS

LINE_WIDTH = shutil.get_terminal_size(fallback=(120, 50)).columns

Expand Down
55 changes: 32 additions & 23 deletions src/modelbench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

import termcolor
from click import echo
from modelgauge.command_line import check_secrets, compact_sut_list, make_suts, validate_uid
from modelgauge.command_line import compact_uid_list, validate_uid
from modelgauge.config import load_secrets_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.locales import DEFAULT_LOCALE, LOCALES, PUBLISHED_LOCALES, validate_locale
from modelgauge.monitoring import PROMETHEUS
from modelgauge.preflight import check_secrets, make_sut
from modelgauge.prompt_sets import PROMPT_SETS, validate_prompt_set
from modelgauge.sut import SUT
from modelgauge.sut_decorator import modelgauge_sut
Expand Down Expand Up @@ -81,7 +82,7 @@ def at_end(result, **kwargs):
@cli.command(help="List known suts")
@local_plugin_dir_option
def list_suts():
print(compact_sut_list(SUTS))
print(compact_uid_list(SUTS))


@cli.command(help="run a benchmark")
Expand All @@ -94,8 +95,16 @@ def list_suts():
@click.option("--max-instances", "-m", type=int, default=100)
@click.option("--debug", default=False, is_flag=True)
@click.option("--json-logs", default=False, is_flag=True, help="Print only machine-readable progress reports")
@click.option("sut_uids", "--sut", "-s", multiple=True, help="SUT uid(s) to run", required=True, callback=validate_uid)
@click.option("--anonymize", type=int, help="Random number seed for consistent anonymization of SUTs")
@click.option(
"sut_uid",
"--sut",
"-s",
multiple=False,
help="SUT UID to run",
required=True,
callback=validate_uid,
)
@click.option("--anonymize", type=int, help="Randon number seed for consistent anonymization SUTs")
@click.option("--threads", default=32, help="How many threads to use per stage")
@click.option(
"--version",
Expand Down Expand Up @@ -135,7 +144,7 @@ def benchmark(
max_instances: int,
debug: bool,
json_logs: bool,
sut_uids: List[str],
sut_uid: str,
anonymize=None,
threads=32,
prompt_set="demo",
Expand All @@ -149,13 +158,12 @@ def benchmark(
locale.lower(),
]

# SUT UIDs are validated in the callback function, so we don't need to validate here
suts = make_suts(sut_uids)
the_sut = make_sut(sut_uid)

# benchmark(s)
benchmarks = [get_benchmark(version, l, prompt_set, evaluator) for l in locales]
run = run_benchmarks_for_suts(
benchmarks, suts, max_instances, debug=debug, json_logs=json_logs, thread_count=threads
run = run_benchmarks_for_sut(
benchmarks, the_sut, max_instances, debug=debug, json_logs=json_logs, thread_count=threads
)
benchmark_scores = score_benchmarks(run)
output_dir.mkdir(exist_ok=True, parents=True)
Expand Down Expand Up @@ -255,17 +263,17 @@ def score_benchmarks(run):
return benchmark_scores


def run_benchmarks_for_suts(benchmarks, suts, max_instances, debug=False, json_logs=False, thread_count=32):
def run_benchmarks_for_sut(benchmarks, sut, max_instances, debug=False, json_logs=False, thread_count=32):
runner = BenchmarkRunner(pathlib.Path("./run"))
runner.secrets = load_secrets_from_config()
runner.benchmarks = benchmarks
runner.suts = suts
runner.sut = sut
runner.max_items = max_instances
runner.debug = debug
runner.thread_count = thread_count
runner.run_tracker = JsonRunTracker() if json_logs else TqdmRunTracker(0.5)

print(f"Starting run for {[b.uid for b in benchmarks]} over {[s.uid for s in suts]}")
print(f"Starting run for {[b.uid for b in benchmarks]} for {sut.uid}")

run = runner.run()

Expand Down Expand Up @@ -343,21 +351,22 @@ def calibrate(update: bool, file) -> None:


def update_standards_to(standards_file):
reference_sut_uids = ["gemma-2-9b-it-hf", "llama-3.1-8b-instruct-turbo-together"]
reference_suts = make_suts(reference_sut_uids)

benchmarks = []
for locale in PUBLISHED_LOCALES:
for prompt_set in PROMPT_SETS.keys():
# we do not want to make demo standards, instead we want to use the practice standards
# we do not want to make demo standards. Instead we want to use the practice standards
if not prompt_set == "demo":
benchmarks.append(get_benchmark("1.0", locale, prompt_set, "ensemble"))
run_result = run_benchmarks_for_suts(benchmarks, reference_suts, None)
all_hazard_numeric_scores = defaultdict(list)
for _, scores_by_sut in run_result.benchmark_scores.items():
for _, benchmark_score in scores_by_sut.items():
for hazard_score in benchmark_score.hazard_scores:
all_hazard_numeric_scores[hazard_score.hazard_definition.uid].append(hazard_score.score.estimate)

reference_sut_uids = ["gemma-2-9b-it-hf", "llama-3.1-8b-instruct-turbo-together"]
for sut_uid in reference_sut_uids:
ref_sut = make_sut(sut_uid)
run_result = run_benchmarks_for_sut(benchmarks, ref_sut, None)
all_hazard_numeric_scores = defaultdict(list)
for _, scores_by_sut in run_result.benchmark_scores.items():
for _, benchmark_score in scores_by_sut.items():
for hazard_score in benchmark_score.hazard_scores:
all_hazard_numeric_scores[hazard_score.hazard_definition.uid].append(hazard_score.score.estimate)

reference_standards = {h: min(s) for h, s in all_hazard_numeric_scores.items() if s}
reference_standards = {k: reference_standards[k] for k in sorted(reference_standards.keys())}
Expand All @@ -374,7 +383,7 @@ def update_standards_to(standards_file):
},
},
"standards": {
"reference_suts": [sut.uid for sut in reference_suts],
"reference_suts": reference_sut_uids,
"reference_standards": reference_standards,
},
}
Expand Down
Loading