1717 create_sut_options ,
1818 display_header ,
1919 display_list_item ,
20+ listify ,
2021 make_suts ,
2122 modelgauge_cli ,
2223 sut_options_options ,
@@ -173,7 +174,7 @@ def run_sut(
173174@modelgauge_cli .command ()
174175@click .option ("--test" , "-t" , help = "Which registered TEST to run." , required = True , callback = validate_uid )
175176@LOCAL_PLUGIN_DIR_OPTION
176- @click .option ("--sut" , "-s" , help = "Which SUT to run." , required = True , callback = validate_uid )
177+ @click .option ("--sut" , "-s" , help = "Which SUT to run." , required = True , multiple = False , callback = validate_uid )
177178@DATA_DIR_OPTION
178179@MAX_TEST_ITEMS_OPTION
179180@click .option (
@@ -214,25 +215,25 @@ def run_test(
214215 sut ,
215216 ]
216217 )
217- sut_obj = suts [0 ]
218+ sut_instance = suts [0 ]
218219
219220 # Current this only knows how to do prompt response, so assert that is what we have.
220- assert isinstance (sut_obj , PromptResponseSUT )
221+ assert isinstance (sut_instance , PromptResponseSUT )
221222 assert isinstance (test_obj , PromptResponseTest )
222223
223- annotator_objs = []
224+ annotator_instances = []
224225 for annotator in test_obj .get_annotators ():
225226 obj = ANNOTATORS .make_instance (annotator , secrets = secrets )
226227 assert isinstance (obj , CompletionAnnotator )
227- annotator_objs .append (obj )
228+ annotator_instances .append (obj )
228229
229230 if output_file is None :
230231 os .makedirs ("output" , exist_ok = True )
231232 output_file = os .path .join ("output" , normalize_filename (f"record_for_{ test } _{ sut } .json" ))
232233 test_record = run_prompt_response_test (
233234 test_obj ,
234- sut_obj ,
235- annotator_objs ,
235+ sut_instance ,
236+ annotator_instances ,
236237 data_dir ,
237238 max_test_items ,
238239 use_caching = not no_caching ,
@@ -253,7 +254,7 @@ def run_test(
253254 "-s" ,
254255 "--sut" ,
255256 help = "Which SUT to run." ,
256- multiple = True ,
257+ multiple = False ,
257258 required = False ,
258259 callback = validate_uid ,
259260)
@@ -295,7 +296,7 @@ def run_job(
295296):
296297 """Run rows in a CSV through (a) SUT(s) and/or a set of annotators.
297298
298- If running SUTs , the file must have 'UID' and 'Text' columns. The output will be saved to a CSV file.
299+ If running a SUT , the file must have 'UID' and 'Text' columns. The output will be saved to a CSV file.
299300 If running ONLY annotators, the file must have 'UID', 'Prompt', 'SUT', and 'Response' columns. The output will be saved to a json lines file.
300301 """
301302 logging .basicConfig (level = logging .DEBUG if debug else logging .INFO )
@@ -312,16 +313,18 @@ def run_job(
312313 else :
313314 ensemble = None
314315
316+ sut_uids = listify (sut_uid )
317+
315318 # Check all objects for missing secrets.
316319 secrets = load_secrets_from_config ()
317320 if sut_uid :
318- check_secrets (secrets , sut_uids = sut_uid , annotator_uids = annotator_uids )
321+ check_secrets (secrets , sut_uids = sut_uids , annotator_uids = annotator_uids )
319322 else :
320323 check_secrets (secrets , annotator_uids = annotator_uids )
321324
322325 suts = {}
323- if sut_uid :
324- all_suts = make_suts (sut_uid )
326+ if sut_uids :
327+ all_suts = make_suts (sut_uids )
325328 for sut in all_suts :
326329 if AcceptsTextPrompt not in sut .capabilities :
327330 raise click .BadParameter (f"{ sut .uid } does not accept text prompts" )
@@ -376,11 +379,11 @@ def show_progress(data):
376379@modelgauge_cli .command ()
377380@sut_options_options
378381@click .option (
379- "sut_uids " ,
382+ "sut_uid " ,
380383 "-s" ,
381384 "--sut" ,
382- help = "Which SUT(s) to run." ,
383- multiple = True ,
385+ help = "Which SUT to run." ,
386+ multiple = False ,
384387 required = False ,
385388 callback = validate_uid ,
386389)
@@ -423,7 +426,7 @@ def show_progress(data):
423426 type = click .Path (exists = True , path_type = pathlib .Path ),
424427)
425428def run_csv_items (
426- sut_uids ,
429+ sut_uid ,
427430 annotator_uids ,
428431 ensemble ,
429432 workers ,
@@ -455,11 +458,14 @@ def run_csv_items(
455458 annotator_uids = annotator_uids + tuple (PRIVATE_ANNOTATOR_SET .annotators )
456459 else :
457460 ensemble = None
461+
462+ sut_uids = listify (sut_uid )
463+
458464 # Check all objects for missing secrets.
459465 secrets = load_secrets_from_config ()
460466 check_secrets (secrets , sut_uids = sut_uids , annotator_uids = annotator_uids )
461467
462- if len ( sut_uids ) :
468+ if sut_uids :
463469 all_suts = make_suts (sut_uids )
464470 suts = {}
465471 for sut in all_suts :
0 commit comments