Skip to content

Commit 2556849

Browse files
authored
rename main and run to cli for consistency (#1148)
1 parent 91443dc commit 2556849

7 files changed

Lines changed: 30 additions & 34 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ amazon = ["modelgauge_amazon"]
117117
all_plugins = ["modelgauge_anthropic", "modelgauge_azure", "modelgauge_baseten", "modelgauge_demo_plugin", "modelgauge_nvidia", "modelgauge_perspective_api", "modelgauge_google", "modelgauge_vertexai", "modelgauge_mistral", "modelgauge_amazon"]
118118

119119
[tool.poetry.scripts]
120-
modelbench = "modelbench.run:cli"
121-
modelgauge = "modelgauge.main:main"
120+
modelbench = "modelbench.cli:cli"
121+
modelgauge = "modelgauge.cli:cli"
122122

123123
[tool.pytest.ini_options]
124124
addopts = [
Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
create_sut_options,
1717
display_header,
1818
display_list_item,
19-
modelgauge_cli,
19+
cli,
2020
sut_options_options,
2121
validate_uid,
2222
)
@@ -38,7 +38,7 @@
3838
logger = logging.getLogger(__name__)
3939

4040

41-
@modelgauge_cli.command(name="list")
41+
@cli.command(name="list")
4242
@LOCAL_PLUGIN_DIR_OPTION
4343
def list_command() -> None:
4444
"""Overview of Plugins, Tests, and SUTs."""
@@ -85,7 +85,7 @@ def format_missing_secrets(missing):
8585
click.echo()
8686

8787

88-
@modelgauge_cli.command()
88+
@cli.command()
8989
@LOCAL_PLUGIN_DIR_OPTION
9090
def list_tests() -> None:
9191
"""List details about all registered tests."""
@@ -94,7 +94,7 @@ def list_tests() -> None:
9494
_display_factory_entry(test_uid, test_entry, secrets)
9595

9696

97-
@modelgauge_cli.command()
97+
@cli.command()
9898
@LOCAL_PLUGIN_DIR_OPTION
9999
def list_suts():
100100
"""List details about all registered SUTs (System Under Test)."""
@@ -103,7 +103,7 @@ def list_suts():
103103
_display_factory_entry(sut_uid, sut, secrets)
104104

105105

106-
@modelgauge_cli.command()
106+
@cli.command()
107107
@LOCAL_PLUGIN_DIR_OPTION
108108
def list_annotators():
109109
"""List details about all registered SUTs (System Under Test)."""
@@ -112,7 +112,7 @@ def list_annotators():
112112
_display_factory_entry(annotator_uid, annotator, secrets)
113113

114114

115-
@modelgauge_cli.command()
115+
@cli.command()
116116
@LOCAL_PLUGIN_DIR_OPTION
117117
def list_secrets() -> None:
118118
"""List details about secrets modelgauge might need."""
@@ -124,7 +124,7 @@ def list_secrets() -> None:
124124
display_header("No secrets used by any installed plugin.")
125125

126126

127-
@modelgauge_cli.command()
127+
@cli.command()
128128
@LOCAL_PLUGIN_DIR_OPTION
129129
@click.option("--sut", "-s", help="Which SUT to run.", required=True)
130130
@sut_options_options
@@ -162,7 +162,7 @@ def run_sut(
162162
click.echo(f"Normalized response: {result.model_dump_json(indent=2)}\n")
163163

164164

165-
@modelgauge_cli.command()
165+
@cli.command()
166166
@click.option("--test", "-t", help="Which registered TEST to run.", required=True, callback=validate_uid)
167167
@LOCAL_PLUGIN_DIR_OPTION
168168
@click.option("--sut", "-s", help="Which SUT to run.", required=True, multiple=False)
@@ -233,7 +233,7 @@ def run_test(
233233
print("Full TestRecord json written to", output_file)
234234

235235

236-
@modelgauge_cli.command()
236+
@cli.command()
237237
@sut_options_options
238238
@click.option(
239239
"sut_uid",
@@ -352,9 +352,5 @@ def show_progress(data):
352352
pipeline_runner.run(show_progress, debug)
353353

354354

355-
def main():
356-
modelgauge_cli()
357-
358-
359355
if __name__ == "__main__":
360-
main()
356+
cli()

src/modelgauge/command_line.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414

1515
@click.group(name="modelgauge")
16-
def modelgauge_cli():
16+
def cli():
1717
"""Run the ModelGauge library from the command line."""
18-
# To add a command, decorate your function with @modelgauge_cli.command().
18+
# To add a command, decorate your function with @cli.command().
1919

2020
# Always create the config directory if it doesn't already exist.
2121
write_default_config()

src/modelgauge/suts/together_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import together # type: ignore
22
from collections import defaultdict
3-
from modelgauge.command_line import display_header, display_list_item, modelgauge_cli
3+
from modelgauge.command_line import display_header, display_list_item, cli
44
from modelgauge.config import load_secrets_from_config
55
from modelgauge.suts.together_client import TogetherApiKey
66

77

8-
@modelgauge_cli.command()
8+
@cli.command()
99
def list_together():
1010
"""List all models available in together.ai."""
1111

tests/modelbench_tests/test_run.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from modelbench.benchmark_runner import BenchmarkRun, BenchmarkRunner
1515
from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore, GeneralPurposeAiChatBenchmarkV1
1616
from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards
17-
from modelbench.run import benchmark, cli, get_benchmark
17+
from modelbench.cli import benchmark, cli, get_benchmark
1818
from modelbench.scoring import ValueEstimate
1919
from modelgauge.base_test import PromptResponseTest
2020
from modelgauge.preflight import make_sut
@@ -136,18 +136,18 @@ def wrapper(*args, **kwargs):
136136
@pytest.fixture(autouse=False)
137137
def mock_run_benchmarks(self, sut, monkeypatch, tmp_path):
138138
mock = MagicMock(return_value=fake_benchmark_run(AHazard(), sut, tmp_path))
139-
monkeypatch.setattr(modelbench.run, "run_benchmarks_for_sut", mock)
139+
monkeypatch.setattr(modelbench.cli, "run_benchmarks_for_sut", mock)
140140
return mock
141141

142142
@pytest.fixture(autouse=False)
143143
def mock_score_benchmarks(self, sut, monkeypatch):
144144
mock = MagicMock(return_value=[self.mock_score(sut)])
145-
monkeypatch.setattr(modelbench.run, "score_benchmarks", mock)
145+
monkeypatch.setattr(modelbench.cli, "score_benchmarks", mock)
146146
return mock
147147

148148
@pytest.fixture(autouse=True)
149149
def do_print_summary(self, monkeypatch):
150-
monkeypatch.setattr(modelbench.run, "print_summary", MagicMock())
150+
monkeypatch.setattr(modelbench.cli, "print_summary", MagicMock())
151151

152152
@pytest.fixture
153153
def runner(self):
@@ -225,7 +225,7 @@ def test_benchmark_multiple_suts_produces_json(
225225
)
226226

227227
mock = MagicMock(return_value=[self.mock_score(sut_uid, benchmark), self.mock_score("demo_yes_no", benchmark)])
228-
monkeypatch.setattr(modelbench.run, "score_benchmarks", mock)
228+
monkeypatch.setattr(modelbench.cli, "score_benchmarks", mock)
229229

230230
result = runner.invoke(
231231
cli,

tests/modelgauge_tests/test_cli.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
from click.testing import CliRunner, Result
1010

11-
from modelgauge import main
11+
from modelgauge import cli
1212
from modelgauge.annotator_registry import ANNOTATORS
1313
from modelgauge.annotator_set import AnnotatorSet
1414
from modelgauge.command_line import validate_uid
@@ -34,7 +34,7 @@
3434

3535
def run_cli(*args) -> Result:
3636
# noinspection PyTypeChecker
37-
result = CliRunner().invoke(main.modelgauge_cli, args, catch_exceptions=False)
37+
result = CliRunner().invoke(cli.cli, args, catch_exceptions=False)
3838
return result
3939

4040

@@ -96,7 +96,7 @@ def test_run_sut_invalid_uid():
9696
def test_run_sut_with_options(mock_translate_text_prompt):
9797
runner = CliRunner()
9898
result = runner.invoke(
99-
main.modelgauge_cli,
99+
cli.cli,
100100
[
101101
"run-sut",
102102
"--sut",
@@ -189,7 +189,7 @@ def test_run_job_sut_only_output_name(caplog, tmp_path, prompts_file):
189189
caplog.set_level(logging.INFO)
190190
runner = CliRunner()
191191
result = runner.invoke(
192-
main.modelgauge_cli,
192+
cli.cli,
193193
["run-job", "--sut", "demo_yes_no", "--output-dir", tmp_path, str(prompts_file)],
194194
catch_exceptions=False,
195195
)
@@ -210,7 +210,7 @@ def test_run_job_with_tag_output_name(caplog, tmp_path, prompts_file):
210210
caplog.set_level(logging.INFO)
211211
runner = CliRunner()
212212
result = runner.invoke(
213-
main.modelgauge_cli,
213+
cli.cli,
214214
["run-job", "--sut", "demo_yes_no", "--output-dir", tmp_path, "--tag", "test", str(prompts_file)],
215215
catch_exceptions=False,
216216
)
@@ -226,7 +226,7 @@ def test_run_job_sut_and_annotator_output_name(caplog, tmp_path, prompts_file):
226226
caplog.set_level(logging.INFO)
227227
runner = CliRunner()
228228
result = runner.invoke(
229-
main.modelgauge_cli,
229+
cli.cli,
230230
[
231231
"run-job",
232232
"--sut",
@@ -257,7 +257,7 @@ def test_run_job_annotators_only_output_name(caplog, tmp_path, prompt_responses_
257257
caplog.set_level(logging.INFO)
258258
runner = CliRunner()
259259
result = runner.invoke(
260-
main.modelgauge_cli,
260+
cli.cli,
261261
["run-job", "--annotator", "demo_annotator", "--output-dir", tmp_path, str(prompt_responses_file)],
262262
catch_exceptions=False,
263263
)
@@ -291,7 +291,7 @@ def evaluate(self, item):
291291
with patch.dict(sys.modules, {"modelgauge.private_ensemble_annotator_set": dummy_module}):
292292
runner = CliRunner()
293293
result = runner.invoke(
294-
main.modelgauge_cli,
294+
cli.cli,
295295
[
296296
"run-job",
297297
"--ensemble",
@@ -318,7 +318,7 @@ def evaluate(self, item):
318318
def test_run_missing_ensemble_raises_error(tmp_path, prompt_responses_file):
319319
runner = CliRunner()
320320
result = runner.invoke(
321-
main.modelgauge_cli,
321+
cli.cli,
322322
[
323323
"run-job",
324324
"--ensemble",

0 commit comments

Comments
 (0)