Skip to content

Commit 91443dc

Browse files
authored
Dynamic SUTs refactor (#1141)
* New SUTFactory object. Move classify_sut_uid there.
* Remove pre-flight registration of dynamic SUTs. Flesh out SUT factory make_instance.
* Dynamic SUTs inject secrets. Non-HF factories return hydrated SUT instances.
* SUT factory also makes dynamic SUTs.
* Update HF factories.
* Try to solve circular imports.
* Use SUT factory for UID validation -- dynamic SUTs default to true.
* check_secrets bypasses dynamic SUTs using SUT factory.
* Remove validate_uid --sut callback.
1 parent e9115ec commit 91443dc

22 files changed

Lines changed: 600 additions & 330 deletions

src/modelbench/run.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,11 @@
1212
import sys
1313
from collections import defaultdict
1414
from datetime import datetime, timezone
15-
from typing import List
1615

1716
import click
1817

1918
import termcolor
2019
from click import echo
21-
from modelgauge.command_line import compact_uid_list, validate_uid
2220
from modelgauge.config import load_secrets_from_config, write_default_config
2321
from modelgauge.load_plugins import load_plugins
2422
from modelgauge.locales import DEFAULT_LOCALE, LOCALES, PUBLISHED_LOCALES, validate_locale
@@ -82,7 +80,7 @@ def at_end(result, **kwargs):
8280
@cli.command(help="List known suts")
8381
@local_plugin_dir_option
8482
def list_suts():
85-
print(compact_uid_list(SUTS))
83+
print(SUTS.compact_uid_list())
8684

8785

8886
@cli.command(help="run a benchmark")
@@ -102,7 +100,6 @@ def list_suts():
102100
multiple=False,
103101
help="SUT UID to run",
104102
required=True,
105-
callback=validate_uid,
106103
)
107104
@click.option("--anonymize", type=int, help="Randon number seed for consistent anonymization SUTs")
108105
@click.option("--threads", default=32, help="How many threads to use per stage")

src/modelgauge/command_line.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from modelgauge.annotator_registry import ANNOTATORS
77
from modelgauge.config import write_default_config
88
from modelgauge.load_plugins import load_plugins
9-
from modelgauge.preflight import validate_sut_uid, listify
9+
from modelgauge.preflight import listify
1010
from modelgauge.sut import SUTOptions
11+
from modelgauge.sut_factory import SUT_FACTORY
1112
from modelgauge.test_registry import TESTS
1213

1314

@@ -42,12 +43,6 @@ def load_local_plugins(_, __, path: pathlib.Path):
4243
__import__(plugin.name)
4344

4445

45-
def compact_uid_list(registry) -> str:
46-
valid_uids = sorted(registry.keys(), key=lambda x: x.lower())
47-
valid_uids_str = "\n\t".join(valid_uids)
48-
return "\t" + valid_uids_str
49-
50-
5146
# Define some reusable options
5247
DATA_DIR_OPTION = click.option(
5348
"--data-dir",
@@ -105,14 +100,13 @@ def validate_uid(ctx, param, value):
105100
"""Callback function for click.option UID validation.
106101
Raises a BadParameter exception if the user-supplied arg(s) are not valid UIDs.
107102
Applicable for parameters '--test', '--sut', and '--annotator'.
108-
SUT IDs are validated in validate_sut_uid via this function.
109103
If no UID is provided (e.g. an empty list or `None`), the value is returned as-is.
110104
"""
111105
if not value:
112106
return value
113107
# Identify what object we are validating UIDs for.
114108
if "--sut" in param.opts:
115-
return validate_sut_uid(value)
109+
registry = SUT_FACTORY
116110
elif "--test" in param.opts:
117111
registry = TESTS
118112
elif "--annotator" in param.opts:
@@ -125,7 +119,7 @@ def validate_uid(ctx, param, value):
125119

126120
unknown_uids = []
127121
for uid in values:
128-
if uid not in registry.keys():
122+
if not registry.knows(uid):
129123
unknown_uids.append(uid)
130124
if len(unknown_uids) == 0:
131125
return value
@@ -137,6 +131,6 @@ def validate_uid(ctx, param, value):
137131
# this is used for all types of UIDs, not just SUTs
138132
def _bad_uid_error(registry, message, hint=""):
139133
raise click.BadParameter(
140-
f"{message}.\nValid options are:\n{compact_uid_list(registry)}",
134+
f"{message}.\nValid options are:\n{registry.compact_uid_list()}",
141135
param_hint=hint,
142136
)

src/modelgauge/dynamic_sut_factory.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from abc import ABC, abstractmethod
2+
from typing import Sequence
23

4+
from modelgauge.dependency_injection import inject_dependencies
5+
from modelgauge.secret_values import InjectSecret, RawSecrets
36
from modelgauge.dynamic_sut_metadata import DynamicSUTMetadata
4-
5-
from modelgauge.secret_values import InjectSecret
7+
from modelgauge.sut import SUT
68

79

810
class ModelNotSupportedError(Exception):
@@ -27,7 +29,18 @@ class UnknownSUTMakerError(Exception):
2729

2830

2931
class DynamicSUTFactory(ABC):
32+
def __init__(self, raw_secrets: RawSecrets):
33+
self.raw_secrets = raw_secrets
34+
35+
def injected_secrets(self):
36+
"""Return the injected secrets as specified by `get_secrets`."""
37+
return inject_dependencies(self.get_secrets(), {}, secrets=self.raw_secrets)[0]
38+
3039
@staticmethod
3140
@abstractmethod
32-
def make_sut(sut_metadata: DynamicSUTMetadata):
41+
def get_secrets() -> list[InjectSecret]:
42+
pass
43+
44+
@abstractmethod
45+
def make_sut(self, sut_metadata: DynamicSUTMetadata) -> SUT:
3346
pass

src/modelgauge/dynamic_sut_finder.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

src/modelgauge/instance_factory.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,7 @@ def _get_entry(self, uid: str) -> FactoryEntry:
8484
try:
8585
entry = self._lookup[uid]
8686
except KeyError:
87-
known_uids = list(self._lookup.keys())
88-
raise KeyError(f"No registration for {uid}. Known uids: {known_uids}")
87+
raise ValueError(f"No registration for {uid}. Known uids:\n{self.compact_uid_list()}")
8988
return entry
9089

9190
def items(self) -> List[Tuple[str, FactoryEntry[_T]]]:
@@ -97,3 +96,12 @@ def keys(self) -> List[str]:
9796
"""List all keys in the registry."""
9897
with self.lock:
9998
return list(self._lookup.keys())
99+
100+
def knows(self, uid: str) -> bool:
101+
"""Check if the factory knows about a given uid."""
102+
return uid in self.keys()
103+
104+
def compact_uid_list(self) -> str:
105+
valid_uids = sorted(list(self._lookup.keys()), key=lambda x: x.lower())
106+
valid_uids_str = "\n\t".join(valid_uids)
107+
return "\t" + valid_uids_str

src/modelgauge/main.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from modelgauge.general import normalize_filename
2626
from modelgauge.instance_factory import FactoryEntry
2727
from modelgauge.load_plugins import list_plugins
28-
from modelgauge.pipeline_runner import AnnotatorRunner, build_runner, PromptPlusAnnotatorRunner, PromptRunner
29-
from modelgauge.preflight import check_secrets, listify, make_sut
28+
from modelgauge.pipeline_runner import build_runner
29+
from modelgauge.preflight import check_secrets, make_sut
3030
from modelgauge.prompt import TextPrompt
3131
from modelgauge.secret_values import get_all_secrets, RawSecrets
3232
from modelgauge.simple_test_runner import run_prompt_response_test
@@ -126,7 +126,7 @@ def list_secrets() -> None:
126126

127127
@modelgauge_cli.command()
128128
@LOCAL_PLUGIN_DIR_OPTION
129-
@click.option("--sut", "-s", help="Which SUT to run.", required=True, callback=validate_uid)
129+
@click.option("--sut", "-s", help="Which SUT to run.", required=True)
130130
@sut_options_options
131131
@click.option("--prompt", help="The full text to send to the SUT.")
132132
@click.option(
@@ -144,9 +144,6 @@ def run_sut(
144144
top_k: Optional[int],
145145
):
146146
"""Send a prompt from the command line to a SUT."""
147-
secrets = load_secrets_from_config()
148-
check_secrets(secrets, sut_uids=[sut])
149-
150147
sut_instance = make_sut(sut)
151148

152149
# Current this only knows how to do prompt response, so assert that is what we have.
@@ -168,7 +165,7 @@ def run_sut(
168165
@modelgauge_cli.command()
169166
@click.option("--test", "-t", help="Which registered TEST to run.", required=True, callback=validate_uid)
170167
@LOCAL_PLUGIN_DIR_OPTION
171-
@click.option("--sut", "-s", help="Which SUT to run.", required=True, multiple=False, callback=validate_uid)
168+
@click.option("--sut", "-s", help="Which SUT to run.", required=True, multiple=False)
172169
@DATA_DIR_OPTION
173170
@MAX_TEST_ITEMS_OPTION
174171
@click.option(
@@ -245,7 +242,6 @@ def run_test(
245242
help="Which SUT to run.",
246243
multiple=False,
247244
required=False,
248-
callback=validate_uid,
249245
)
250246
@click.option(
251247
"annotator_uids",
@@ -309,7 +305,7 @@ def run_job(
309305
sut = make_sut(sut_uid)
310306
if AcceptsTextPrompt not in sut.capabilities:
311307
raise click.BadParameter(f"{sut_uid} does not accept text prompts")
312-
check_secrets(secrets, sut_uids=[sut_uid], annotator_uids=annotator_uids)
308+
check_secrets(secrets, annotator_uids=annotator_uids)
313309
sut_options = create_sut_options(max_tokens, temp, top_p, top_k)
314310
else:
315311
sut = None

src/modelgauge/preflight.py

Lines changed: 5 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
1+
from typing import List
2+
13
from modelgauge.annotator_registry import ANNOTATORS
24
from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config
3-
from modelgauge.dynamic_sut_finder import make_dynamic_sut_for
45
from modelgauge.secret_values import MissingSecretValues
5-
6-
7-
from typing import List
8-
9-
from modelgauge.sut import SUTNotFoundException
6+
from modelgauge.sut_factory import SUT_FACTORY
107
from modelgauge.sut_registry import SUTS
118
from modelgauge.test_registry import TESTS
129

13-
SUT_TYPE_DYNAMIC = "dynamic"
14-
SUT_TYPE_KNOWN = "known"
15-
SUT_TYPE_UNKNOWN = "unknown"
16-
1710

1811
def listify(value):
1912
"""Some functions accept a single UID or a list of them. This turns a single one into a list."""
@@ -33,7 +26,7 @@ def check_secrets(secrets, sut_uids=None, test_uids=None, annotator_uids=None):
3326
"""Checks if all secrets are present for the given UIDs. Raises an error and reports all missing secrets."""
3427
missing_secrets: List[MissingSecretValues] = []
3528
if sut_uids is not None:
36-
missing_secrets.extend(get_missing_secrets(secrets, SUTS, listify(sut_uids)))
29+
missing_secrets.extend(get_missing_secrets(secrets, SUT_FACTORY, listify(sut_uids)))
3730
if test_uids is not None:
3831
missing_secrets.extend(get_missing_secrets(secrets, TESTS, test_uids))
3932
# Check secrets for the annotators in the test as well.
@@ -46,39 +39,9 @@ def check_secrets(secrets, sut_uids=None, test_uids=None, annotator_uids=None):
4639
return True
4740

4841

49-
def classify_sut_uid(uid: str) -> str:
50-
if uid in SUTS.keys():
51-
return SUT_TYPE_KNOWN
52-
elif ":" in uid:
53-
return SUT_TYPE_DYNAMIC
54-
else:
55-
return SUT_TYPE_UNKNOWN
56-
57-
58-
def validate_sut_uid(sut_uid):
59-
# A blank sut uid is OK for some invocations of modelgauge.
60-
# Commands where a non-blank sut uid is required must enforce that with click
61-
if not sut_uid:
62-
return sut_uid
63-
64-
sut_type = classify_sut_uid(sut_uid)
65-
if sut_type == SUT_TYPE_KNOWN:
66-
pass
67-
elif sut_type == SUT_TYPE_DYNAMIC:
68-
dynamic_sut = make_dynamic_sut_for(sut_uid) # a tuple that can be splatted for SUTS.register
69-
if dynamic_sut:
70-
SUTS.register(*dynamic_sut)
71-
else:
72-
raise SUTNotFoundException(f"{sut_uid} is not a valid dynamic SUT UID")
73-
else:
74-
raise SUTNotFoundException(f"{sut_uid} is not a valid SUT UID")
75-
76-
return sut_uid
77-
78-
7942
def make_sut(sut_uid: str):
8043
"""Checks that user has all required secrets and returns instantiated SUT."""
8144
secrets = load_secrets_from_config()
8245
check_secrets(secrets, sut_uids=[sut_uid])
83-
sut = SUTS.make_instance(sut_uid, secrets=secrets)
46+
sut = SUT_FACTORY.make_instance(sut_uid, secrets=secrets)
8447
return sut

src/modelgauge/sut.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@
1414
REFUSAL_RESPONSE = ""
1515

1616

17-
class SUTNotFoundException(Exception):
18-
pass
19-
20-
2117
class SUTOptions(BaseModel):
2218
"""
2319
An exhaustive set of options that could potentially be desired by a SUT.

0 commit comments

Comments
 (0)