Skip to content

Commit 40bbaba

Browse files
authored
Make together retry more. (#978)
Allow changing number of threads when running benchmarks. Speed up `retry` tests and other sleeping tests.
1 parent 43cf7aa commit 40bbaba

9 files changed

Lines changed: 18 additions & 193 deletions

File tree

src/modelbench/run.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ def cli() -> None:
8484
@click.option("--json-logs", default=False, is_flag=True, help="Print only machine-readable progress reports")
8585
@click.option("sut_uids", "--sut", "-s", multiple=True, help="SUT uid(s) to run", required=True, callback=validate_uid)
8686
@click.option("--anonymize", type=int, help="Random number seed for consistent anonymization of SUTs")
87-
@click.option("--parallel", default=False, help="Obsolete flag, soon to be removed")
87+
@click.option("--threads", default=32, help="How many threads to use per stage")
8888
@click.option(
8989
"--version",
9090
"-v",
@@ -125,12 +125,10 @@ def benchmark(
125125
json_logs: bool,
126126
sut_uids: List[str],
127127
anonymize=None,
128-
parallel=False,
128+
threads=32,
129129
prompt_set="demo",
130130
evaluator="default",
131131
) -> None:
132-
if parallel:
133-
click.echo("--parallel option unnecessary; benchmarks are now always run in parallel")
134132
start_time = datetime.now(timezone.utc)
135133
if locale == "all":
136134
locales = LOCALES
@@ -143,7 +141,9 @@ def benchmark(
143141
suts = get_suts(sut_uids)
144142
benchmarks = [get_benchmark(version, l, prompt_set, evaluator) for l in locales]
145143

146-
run = run_benchmarks_for_suts(benchmarks, suts, max_instances, debug=debug, json_logs=json_logs)
144+
run = run_benchmarks_for_suts(
145+
benchmarks, suts, max_instances, debug=debug, json_logs=json_logs, thread_count=threads
146+
)
147147
benchmark_scores = score_benchmarks(run)
148148
output_dir.mkdir(exist_ok=True, parents=True)
149149
for b in benchmarks:

src/modelgauge/annotation_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def handle_item(self, item: SutInteraction):
112112

113113
class AnnotatorWorkers(CachingPipe):
114114
def __init__(self, annotators: dict[str, Annotator], workers=None, cache_path=None):
115+
self.sleep_time = 10
115116
if workers is None:
116117
workers = 8
117118
super().__init__(thread_count=workers, cache_path=cache_path)
@@ -140,7 +141,7 @@ def handle_uncached_item(self, item):
140141
logger.warning(
141142
f"Exception calling annotator {annotator_uid} on attempt {tries}: {e}\nRetrying.....", exc_info=True
142143
)
143-
time.sleep(10)
144+
time.sleep(self.sleep_time)
144145
result = annotator.translate_response(request, response)
145146
self.annotation_counts[annotator_uid] += 1
146147
return sut_interaction, annotator_uid, result

src/modelgauge/api_server.py

Lines changed: 0 additions & 116 deletions
This file was deleted.

src/modelgauge/prompt_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def handle_item(self, item):
150150

151151
class PromptSutWorkers(CachingPipe):
152152
def __init__(self, suts: dict[str, SUT], sut_options: Optional[SUTOptions] = None, workers=None, cache_path=None):
153+
self.sleep_time = 10
153154
if workers is None:
154155
workers = 8
155156
super().__init__(thread_count=workers, cache_path=cache_path)
@@ -178,7 +179,7 @@ def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> SUTRespon
178179
break
179180
except Exception as e:
180181
logger.warning(f"Exception calling SUT {sut.uid} on attempt {tries}: {e}\nRetrying.....", exc_info=True)
181-
time.sleep(10)
182+
time.sleep(self.sleep_time)
182183
result = sut.translate_response(request, response)
183184
self.sut_response_counts[sut.uid] += 1
184185
return result

src/modelgauge/suts/together_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _retrying_post(url, headers, json_payload):
3232
"""HTTP Post with retry behavior."""
3333
session = requests.Session()
3434
retries = Retry(
35-
total=7,
35+
total=15,
3636
backoff_factor=2,
3737
status_forcelist=[
3838
408, # Request Timeout

tests/modelgauge_tests/test_annotation_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def test_annotator_worker_retries_until_success():
251251
annotator.annotate = mock
252252

253253
w = AnnotatorWorkers({"fake-annotator": annotator})
254+
w.sleep_time = 0.01
254255
sut_interaction = make_sut_interaction("1", "prompt", "sut", "response")
255256
result = w.handle_item((sut_interaction, "fake-annotator"))
256257

tests/modelgauge_tests/test_api_server.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

tests/modelgauge_tests/test_prompt_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def test_prompt_sut_worker_retries_until_success(suts):
198198
prompt_with_context = TestItem(source_id="1", prompt=TextPrompt(text="a prompt"))
199199

200200
w = PromptSutWorkers(suts)
201+
w.sleep_time = 0.01
201202
result = w.handle_item((prompt_with_context, "fake1"))
202203
assert result == SutInteraction(prompt_with_context, "fake1", SUTResponse(text="a response"))
203204
assert mock.call_count == num_exceptions + 1

tests/modelgauge_tests/test_retry_decorator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import patch
2+
13
import pytest
24
import time
35

@@ -28,7 +30,8 @@ def always_fail():
2830
raise KeyError("Intentional failure")
2931

3032
with pytest.raises(KeyError):
31-
always_fail()
33+
with patch("time.sleep") as patched_sleep:
34+
always_fail()
3235

3336
assert attempt_counter == BASE_RETRY_COUNT
3437

@@ -44,7 +47,8 @@ def succeed_before_base_retry_total():
4447
raise ValueError("Intentional failure")
4548
return "success"
4649

47-
assert succeed_before_base_retry_total() == "success"
50+
with patch("time.sleep") as patched_sleep:
51+
assert succeed_before_base_retry_total() == "success"
4852
assert attempt_counter == BASE_RETRY_COUNT
4953

5054

0 commit comments

Comments
 (0)