Skip to content

Commit e09639f

Browse files
authored
add huggingface-specific retry logic (#1247)
* add huggingface-specific retry logic
* add test for retry
* also retry rate-limit errors
* remove unused constant that was set for a test approach that didn't pan out
* remove unused import
1 parent d3ae678 commit e09639f

2 files changed

Lines changed: 32 additions & 12 deletions

File tree

src/modelgauge/suts/huggingface_chat_completion.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from huggingface_hub import get_inference_endpoint, InferenceClient, InferenceEndpointStatus # type: ignore
66
from huggingface_hub.utils import HfHubHTTPError # type: ignore
77
from pydantic import BaseModel
8-
8+
from tenacity import retry, TryAgain, stop_after_attempt, wait_random_exponential
99
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
1010
from modelgauge.prompt import TextPrompt, ChatPrompt
1111
from modelgauge.secret_values import InjectSecret
@@ -15,6 +15,7 @@
1515
from modelgauge.sut_registry import SUTS
1616

1717
HUGGING_FACE_TIMEOUT = 60 * 20
18+
HUGGING_FACE_NUM_RETRIES = 7
1819

1920

2021
class ChatMessage(BaseModel):
@@ -56,21 +57,30 @@ def _create_client(self) -> InferenceClient:
5657
"""Create the InferenceClient for the SUT. Must be implemented by subclasses."""
5758
pass
5859

60+
@retry(stop=stop_after_attempt(HUGGING_FACE_NUM_RETRIES), wait=wait_random_exponential())
5961
def evaluate(self, request: HuggingFaceChatCompletionRequest) -> HuggingFaceChatCompletionOutput:
6062
if self.client is None:
6163
self.client = self._create_client()
6264

63-
request_dict = request.model_dump(exclude_none=True)
64-
response = self.client.chat_completion(**request_dict) # type: ignore
65-
# Convert to cacheable pydantic object.
66-
return HuggingFaceChatCompletionOutput(
67-
choices=[asdict(choice) for choice in response.choices],
68-
created=response.created,
69-
id=response.id,
70-
model=response.model,
71-
system_fingerprint=response.system_fingerprint,
72-
usage=asdict(response.usage),
73-
)
65+
try:
66+
request_dict = request.model_dump(exclude_none=True)
67+
response = self.client.chat_completion(**request_dict) # type: ignore
68+
# Convert to cacheable pydantic object.
69+
return HuggingFaceChatCompletionOutput(
70+
choices=[asdict(choice) for choice in response.choices],
71+
created=response.created,
72+
id=response.id,
73+
model=response.model,
74+
system_fingerprint=response.system_fingerprint,
75+
usage=asdict(response.usage),
76+
)
77+
except HfHubHTTPError as hf_error:
78+
if hf_error.response.status_code >= 500 or hf_error.response.status_code == 429:
79+
raise TryAgain
80+
else:
81+
raise
82+
except Exception as other_error:
83+
raise
7484

7585
def translate_response(
7686
self, request: HuggingFaceChatCompletionRequest, response: HuggingFaceChatCompletionOutput

tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import Mock, patch
33

44
import pytest
5+
56
from huggingface_hub import (
67
ChatCompletionOutput,
78
ChatCompletionOutputComplete,
@@ -13,6 +14,7 @@
1314
InferenceEndpointStatus,
1415
) # type: ignore
1516
from huggingface_hub.utils import HfHubHTTPError # type: ignore
17+
from tenacity import RetryError, stop_after_attempt, wait_none
1618

1719
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
1820
from modelgauge.prompt import TextPrompt, ChatPrompt, ChatRole
@@ -355,3 +357,11 @@ def test_huggingface_chat_completion_translate_response_with_logprobs(fake_sut):
355357
),
356358
],
357359
)
360+
361+
362+
def test_huggingface_evaluate_retries(fake_sut, monkeypatch):
363+
request = _make_sut_request()
364+
monkeypatch.setattr(fake_sut.evaluate.retry, "stop", stop_after_attempt(1))
365+
monkeypatch.setattr(fake_sut.evaluate.retry, "wait", wait_none())
366+
with pytest.raises(RetryError):
367+
fake_sut.evaluate(request)

0 commit comments

Comments
 (0)