Skip to content

Commit 1a33ba2

Browse files
committed
remove support for multiple SUT UIDs as inputs
1 parent 707233a commit 1a33ba2

4 files changed

Lines changed: 60 additions & 35 deletions

File tree

src/modelgauge/command_line.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ def compact_uid_list(registry) -> str:
5151
return "\t" + valid_uids_str
5252

5353

54+
def listify(value):
    """Wrap a lone string in a one-element list; return anything else untouched.

    Callers may hand over either a single UID (a ``str``) or an existing
    collection of UIDs. Only ``str`` inputs are wrapped. Every other value
    (lists, tuples, ``None``, arbitrary objects) is returned as-is, so the
    result is not guaranteed to be a ``list``.
    """
    return [value] if isinstance(value, str) else value
61+
62+
5463
# Define some reusable options
5564
DATA_DIR_OPTION = click.option(
5665
"--data-dir",
@@ -117,13 +126,8 @@ def _validate_sut_uid(ctx, param, value):
117126
# and that is enforced in the input validation. This function can still
118127
# handle multiple SUT UIDs at a time to match validate_uid's semantics,
119128
# and because it doesn't know or care about the logic in the caller.
120-
if isinstance(value, str):
121-
sut_uids = [value]
122-
else:
123-
sut_uids = value
124-
129+
sut_uids = listify(value)
125130
requested_sut_uids = classify_sut_uids(sut_uids)
126-
127131
valid_sut_uids = requested_sut_uids["known"]
128132
for sut_uid in requested_sut_uids["dynamic"]:
129133
dynamic_sut = make_dynamic_sut_for(sut_uid) # a tuple that can be splatted for SUTS.register
@@ -166,10 +170,7 @@ def validate_uid(ctx, param, value):
166170
raise ValueError(f"Cannot validate UID for unknown parameter: {param.opts}")
167171

168172
# This function handles multi-values and single values.
169-
if isinstance(value, str):
170-
values = [value]
171-
else:
172-
values = value
173+
values = listify(value)
173174

174175
unknown_uids = []
175176
for uid in values:
@@ -184,6 +185,7 @@ def validate_uid(ctx, param, value):
184185

185186
def get_missing_secrets(secrets, registry, uids):
    """Collect the missing secrets reported by ``registry`` for each UID.

    ``uids`` is passed through ``listify`` first, so a single UID string is
    treated as a one-element list. For every UID,
    ``registry.get_missing_dependencies(uid, secrets=secrets)`` is called and
    its results are appended, in order, to the returned list.
    """
    collected: List[MissingSecretValues] = []
    for one_uid in listify(uids):
        collected.extend(registry.get_missing_dependencies(one_uid, secrets=secrets))
    return collected
@@ -210,10 +212,7 @@ def classify_sut_uids(uids):
210212
"""The CLI now accepts dynamic SUT ids (e.g. "deepseek-ai/DeepSeek-V3:together:hfrelay") in addition to
211213
pre-registered SUT ids (e.g. "phi-3.5-moe-instruct"). SUT creation and validation are different
212214
between those two types. This function returns the SUT ids organized by type."""
213-
if isinstance(uids, str):
214-
uids = [
215-
uids,
216-
]
215+
uids = listify(uids)
217216
if len(uids) < 1:
218217
_bad_uid_error(SUTS, "Please provide at least one SUT uid.")
219218
identified = {"known": [], "dynamic": [], "unknown": []}

src/modelgauge/main.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
create_sut_options,
1818
display_header,
1919
display_list_item,
20+
listify,
2021
make_suts,
2122
modelgauge_cli,
2223
sut_options_options,
@@ -173,7 +174,7 @@ def run_sut(
173174
@modelgauge_cli.command()
174175
@click.option("--test", "-t", help="Which registered TEST to run.", required=True, callback=validate_uid)
175176
@LOCAL_PLUGIN_DIR_OPTION
176-
@click.option("--sut", "-s", help="Which SUT to run.", required=True, callback=validate_uid)
177+
@click.option("--sut", "-s", help="Which SUT to run.", required=True, multiple=False, callback=validate_uid)
177178
@DATA_DIR_OPTION
178179
@MAX_TEST_ITEMS_OPTION
179180
@click.option(
@@ -214,25 +215,25 @@ def run_test(
214215
sut,
215216
]
216217
)
217-
sut_obj = suts[0]
218+
sut_instance = suts[0]
218219

219220
# Currently this only knows how to do prompt response, so assert that is what we have.
220-
assert isinstance(sut_obj, PromptResponseSUT)
221+
assert isinstance(sut_instance, PromptResponseSUT)
221222
assert isinstance(test_obj, PromptResponseTest)
222223

223-
annotator_objs = []
224+
annotator_instances = []
224225
for annotator in test_obj.get_annotators():
225226
obj = ANNOTATORS.make_instance(annotator, secrets=secrets)
226227
assert isinstance(obj, CompletionAnnotator)
227-
annotator_objs.append(obj)
228+
annotator_instances.append(obj)
228229

229230
if output_file is None:
230231
os.makedirs("output", exist_ok=True)
231232
output_file = os.path.join("output", normalize_filename(f"record_for_{test}_{sut}.json"))
232233
test_record = run_prompt_response_test(
233234
test_obj,
234-
sut_obj,
235-
annotator_objs,
235+
sut_instance,
236+
annotator_instances,
236237
data_dir,
237238
max_test_items,
238239
use_caching=not no_caching,
@@ -253,7 +254,7 @@ def run_test(
253254
"-s",
254255
"--sut",
255256
help="Which SUT to run.",
256-
multiple=True,
257+
multiple=False,
257258
required=False,
258259
callback=validate_uid,
259260
)
@@ -295,7 +296,7 @@ def run_job(
295296
):
296297
"""Run rows in a CSV through (a) SUT(s) and/or a set of annotators.
297298
298-
If running SUTs, the file must have 'UID' and 'Text' columns. The output will be saved to a CSV file.
299+
If running a SUT, the file must have 'UID' and 'Text' columns. The output will be saved to a CSV file.
299300
If running ONLY annotators, the file must have 'UID', 'Prompt', 'SUT', and 'Response' columns. The output will be saved to a json lines file.
300301
"""
301302
logging.basicConfig(level=logging.DEBUG if debug else logging.INFO)
@@ -312,16 +313,18 @@ def run_job(
312313
else:
313314
ensemble = None
314315

316+
sut_uids = listify(sut_uid)
317+
315318
# Check all objects for missing secrets.
316319
secrets = load_secrets_from_config()
317320
if sut_uid:
318-
check_secrets(secrets, sut_uids=sut_uid, annotator_uids=annotator_uids)
321+
check_secrets(secrets, sut_uids=sut_uids, annotator_uids=annotator_uids)
319322
else:
320323
check_secrets(secrets, annotator_uids=annotator_uids)
321324

322325
suts = {}
323-
if sut_uid:
324-
all_suts = make_suts(sut_uid)
326+
if sut_uids:
327+
all_suts = make_suts(sut_uids)
325328
for sut in all_suts:
326329
if AcceptsTextPrompt not in sut.capabilities:
327330
raise click.BadParameter(f"{sut.uid} does not accept text prompts")
@@ -376,11 +379,11 @@ def show_progress(data):
376379
@modelgauge_cli.command()
377380
@sut_options_options
378381
@click.option(
379-
"sut_uids",
382+
"sut_uid",
380383
"-s",
381384
"--sut",
382-
help="Which SUT(s) to run.",
383-
multiple=True,
385+
help="Which SUT to run.",
386+
multiple=False,
384387
required=False,
385388
callback=validate_uid,
386389
)
@@ -423,7 +426,7 @@ def show_progress(data):
423426
type=click.Path(exists=True, path_type=pathlib.Path),
424427
)
425428
def run_csv_items(
426-
sut_uids,
429+
sut_uid,
427430
annotator_uids,
428431
ensemble,
429432
workers,
@@ -455,11 +458,14 @@ def run_csv_items(
455458
annotator_uids = annotator_uids + tuple(PRIVATE_ANNOTATOR_SET.annotators)
456459
else:
457460
ensemble = None
461+
462+
sut_uids = listify(sut_uid)
463+
458464
# Check all objects for missing secrets.
459465
secrets = load_secrets_from_config()
460466
check_secrets(secrets, sut_uids=sut_uids, annotator_uids=annotator_uids)
461467

462-
if len(sut_uids):
468+
if sut_uids:
463469
all_suts = make_suts(sut_uids)
464470
suts = {}
465471
for sut in all_suts:

src/modelgauge/secret_values.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from dataclasses import dataclass
33
from typing import Generic, List, Mapping, Optional, Sequence, Type, TypeVar
44

5-
from pydantic import BaseModel
6-
75
from modelgauge.general import get_concrete_subclasses
86

7+
from pydantic import BaseModel
8+
99

1010
def loggable_secret(secret: str) -> str:
1111
if str is None or not isinstance(secret, str):

tests/modelgauge_tests/test_cli.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from modelgauge import main
1313
from modelgauge.annotator_registry import ANNOTATORS
1414
from modelgauge.annotator_set import AnnotatorSet
15-
from modelgauge.command_line import _validate_sut_uid, check_secrets, classify_sut_uids, validate_uid
15+
from modelgauge.command_line import _validate_sut_uid, check_secrets, classify_sut_uids, listify, validate_uid
1616
from modelgauge.config import MissingSecretsFromConfig
1717
from modelgauge.data_schema import (
1818
DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA,
@@ -354,7 +354,6 @@ def test_run_job_sut_only_output_name(caplog, tmp_path, prompts_file):
354354
["run-job", "--sut", "demo_yes_no", "--output-dir", tmp_path, str(prompts_file)],
355355
catch_exceptions=False,
356356
)
357-
print(result.output)
358357
assert result.exit_code == 0
359358

360359
out_path = Path(re.findall(r"\S+\.csv", caplog.text)[0])
@@ -598,3 +597,24 @@ def test_validate_uid():
598597
)
599598
== "my-fake-annotator"
600599
)
600+
601+
602+
def test_listify():
    """Check that listify wraps a string and passes every other value through."""
    # A lone string comes back as a one-element list.
    assert listify("string") == ["string"]

    class Noop:
        pass

    def noop():
        pass

    instance = Noop()

    # Non-string values (list, tuple, function, class, instance) are returned unchanged.
    for passthrough in (["a", "b"], ("a", "b"), noop, Noop, instance):
        assert listify(passthrough) == passthrough

    # None is not converted to an empty list.
    assert listify(None) is None

0 commit comments

Comments
 (0)