|
5 | 5 | from huggingface_hub import get_inference_endpoint, InferenceClient, InferenceEndpointStatus # type: ignore |
6 | 6 | from huggingface_hub.utils import HfHubHTTPError # type: ignore |
7 | 7 | from pydantic import BaseModel |
8 | | - |
| 8 | +from tenacity import retry, TryAgain, stop_after_attempt, wait_random_exponential |
9 | 9 | from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken |
10 | 10 | from modelgauge.prompt import TextPrompt, ChatPrompt |
11 | 11 | from modelgauge.secret_values import InjectSecret |
|
15 | 15 | from modelgauge.sut_registry import SUTS |
16 | 16 |
|
17 | 17 | HUGGING_FACE_TIMEOUT = 60 * 20 |
| 18 | +HUGGING_FACE_NUM_RETRIES = 7 |
18 | 19 |
|
19 | 20 |
|
20 | 21 | class ChatMessage(BaseModel): |
@@ -56,21 +57,30 @@ def _create_client(self) -> InferenceClient: |
56 | 57 | """Create the InferenceClient for the SUT. Must be implemented by subclasses.""" |
57 | 58 | pass |
58 | 59 |
|
| 60 | + @retry(stop=stop_after_attempt(HUGGING_FACE_NUM_RETRIES), wait=wait_random_exponential()) |
59 | 61 | def evaluate(self, request: HuggingFaceChatCompletionRequest) -> HuggingFaceChatCompletionOutput: |
60 | 62 | if self.client is None: |
61 | 63 | self.client = self._create_client() |
62 | 64 |
|
63 | | - request_dict = request.model_dump(exclude_none=True) |
64 | | - response = self.client.chat_completion(**request_dict) # type: ignore |
65 | | - # Convert to cacheable pydantic object. |
66 | | - return HuggingFaceChatCompletionOutput( |
67 | | - choices=[asdict(choice) for choice in response.choices], |
68 | | - created=response.created, |
69 | | - id=response.id, |
70 | | - model=response.model, |
71 | | - system_fingerprint=response.system_fingerprint, |
72 | | - usage=asdict(response.usage), |
73 | | - ) |
| 65 | + try: |
| 66 | + request_dict = request.model_dump(exclude_none=True) |
| 67 | + response = self.client.chat_completion(**request_dict) # type: ignore |
| 68 | + # Convert to cacheable pydantic object. |
| 69 | + return HuggingFaceChatCompletionOutput( |
| 70 | + choices=[asdict(choice) for choice in response.choices], |
| 71 | + created=response.created, |
| 72 | + id=response.id, |
| 73 | + model=response.model, |
| 74 | + system_fingerprint=response.system_fingerprint, |
| 75 | + usage=asdict(response.usage), |
| 76 | + ) |
| 77 | + except HfHubHTTPError as hf_error: |
| 78 | + if hf_error.response.status_code >= 500 or hf_error.response.status_code == 429: |
| 79 | + raise TryAgain |
| 80 | + else: |
| 81 | + raise |
| 82 | + except Exception as other_error: |
| 83 | + raise |
74 | 84 |
|
75 | 85 | def translate_response( |
76 | 86 | self, request: HuggingFaceChatCompletionRequest, response: HuggingFaceChatCompletionOutput |
|
0 commit comments