Skip to content

Commit a1c751d

Browse files
authored
Serverless huggingface sut (#986)
* update huggingface_hub * new sut * update huggingface to get nebius * nebius sut
1 parent 96b894d commit a1c751d

4 files changed

Lines changed: 732 additions & 683 deletions

File tree

plugins/huggingface/modelgauge/suts/huggingface_chat_completion.py

Lines changed: 92 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from abc import ABC, abstractmethod
12
from dataclasses import asdict
23
from typing import Dict, List, Optional
34

@@ -22,6 +23,7 @@ class ChatMessage(BaseModel):
2223

2324

2425
class HuggingFaceChatCompletionRequest(BaseModel):
26+
model: Optional[str] = None
2527
messages: List[ChatMessage]
2628
logprobs: bool
2729
top_logprobs: Optional[int] = None
@@ -39,18 +41,65 @@ class HuggingFaceChatCompletionOutput(BaseModel):
3941
usage: Optional[Dict] = None
4042

4143

42-
@modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
43-
class HuggingFaceChatCompletionSUT(
44-
PromptResponseSUT[HuggingFaceChatCompletionRequest, HuggingFaceChatCompletionOutput]
44+
class BaseHuggingFaceChatCompletionSUT(
45+
PromptResponseSUT[HuggingFaceChatCompletionRequest, HuggingFaceChatCompletionOutput], ABC
4546
):
46-
"""A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""
47+
"""A Huggingface SUT that uses the chat_completion API."""
4748

48-
def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenceToken):
49+
def __init__(self, uid: str, token: HuggingFaceInferenceToken):
4950
super().__init__(uid)
5051
self.token = token
51-
self.inference_endpoint = inference_endpoint
5252
self.client = None
5353

54+
@abstractmethod
55+
def _create_client(self) -> InferenceClient:
56+
"""Create the InferenceClient for the SUT. Must be implemented by subclasses."""
57+
pass
58+
59+
def evaluate(self, request: HuggingFaceChatCompletionRequest) -> HuggingFaceChatCompletionOutput:
60+
if self.client is None:
61+
self.client = self._create_client()
62+
63+
request_dict = request.model_dump(exclude_none=True)
64+
response = self.client.chat_completion(**request_dict) # type: ignore
65+
# Convert to cacheable pydantic object.
66+
return HuggingFaceChatCompletionOutput(
67+
choices=[asdict(choice) for choice in response.choices],
68+
created=response.created,
69+
id=response.id,
70+
model=response.model,
71+
system_fingerprint=response.system_fingerprint,
72+
usage=asdict(response.usage),
73+
)
74+
75+
def translate_response(
76+
self, request: HuggingFaceChatCompletionRequest, response: HuggingFaceChatCompletionOutput
77+
) -> SUTResponse:
78+
assert len(response.choices) == 1, f"Expected a single response message, got {len(response.choices)}."
79+
choice = response.choices[0]
80+
text = choice["message"]["content"]
81+
assert text is not None
82+
logprobs: Optional[List[TopTokens]] = None
83+
if request.logprobs:
84+
logprobs = []
85+
assert choice["logprobs"] is not None, "Expected logprobs, but not returned."
86+
lobprobs_sequence = choice["logprobs"]["content"]
87+
for token in lobprobs_sequence:
88+
top_tokens = []
89+
for top_logprob in token["top_logprobs"]:
90+
top_tokens.append(TokenProbability(token=top_logprob["token"], logprob=top_logprob["logprob"]))
91+
logprobs.append(TopTokens(top_tokens=top_tokens))
92+
return SUTResponse(text=text, top_logprobs=logprobs)
93+
94+
95+
@modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
96+
class HuggingFaceChatCompletionDedicatedSUT(BaseHuggingFaceChatCompletionSUT):
97+
"""A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""
98+
99+
def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenceToken):
100+
super().__init__(uid, token)
101+
self.inference_endpoint = inference_endpoint
102+
54103
def _create_client(self):
55104
endpoint = get_inference_endpoint(self.inference_endpoint, token=self.token.value)
56105

@@ -74,7 +123,7 @@ def _create_client(self):
74123
f"Endpoint is not running: Please contact admin to ensure endpoint is starting or running (status: {endpoint.status})"
75124
)
76125

77-
self.client = InferenceClient(base_url=endpoint.url, token=self.token.value)
126+
return InferenceClient(base_url=endpoint.url, token=self.token.value)
78127

79128
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest:
80129
logprobs = False
@@ -86,76 +135,76 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Hugg
86135
**options.model_dump(),
87136
)
88137

89-
def evaluate(self, request: HuggingFaceChatCompletionRequest) -> HuggingFaceChatCompletionOutput:
90-
if self.client is None:
91-
self._create_client()
92138

93-
request_dict = request.model_dump(exclude_none=True)
94-
response = self.client.chat_completion(**request_dict) # type: ignore
95-
# Convert to cacheable pydantic object.
96-
return HuggingFaceChatCompletionOutput(
97-
choices=[asdict(choice) for choice in response.choices],
98-
created=response.created,
99-
id=response.id,
100-
model=response.model,
101-
system_fingerprint=response.system_fingerprint,
102-
usage=asdict(response.usage),
139+
@modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
140+
class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT):
141+
"""A SUT hosted by an inference provider on huggingface."""
142+
143+
def __init__(self, uid: str, model: str, provider: str, token: HuggingFaceInferenceToken):
144+
super().__init__(uid, token)
145+
self.model = model
146+
self.provider = provider
147+
148+
def _create_client(self):
149+
return InferenceClient(
150+
provider=self.provider,
151+
api_key=self.token.value,
103152
)
104153

105-
def translate_response(
106-
self, request: HuggingFaceChatCompletionRequest, response: HuggingFaceChatCompletionOutput
107-
) -> SUTResponse:
108-
assert len(response.choices) == 1, f"Expected a single response message, got {len(response.choices)}."
109-
choice = response.choices[0]
110-
text = choice["message"]["content"]
111-
assert text is not None
112-
logprobs: Optional[List[TopTokens]] = None
113-
if request.logprobs:
114-
logprobs = []
115-
assert choice["logprobs"] is not None, "Expected logprobs, but not returned."
116-
lobprobs_sequence = choice["logprobs"]["content"]
117-
for token in lobprobs_sequence:
118-
top_tokens = []
119-
for top_logprob in token["top_logprobs"]:
120-
top_tokens.append(TokenProbability(token=top_logprob["token"], logprob=top_logprob["logprob"]))
121-
logprobs.append(TopTokens(top_tokens=top_tokens))
122-
return SUTResponse(text=text, top_logprobs=logprobs)
154+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest:
155+
logprobs = False
156+
if options.top_logprobs is not None:
157+
logprobs = True
158+
return HuggingFaceChatCompletionRequest(
159+
model=self.model,
160+
messages=[ChatMessage(role="user", content=prompt.text)],
161+
logprobs=logprobs,
162+
**options.model_dump(),
163+
)
123164

124165

125166
HF_SECRET = InjectSecret(HuggingFaceInferenceToken)
126167

127168
SUTS.register(
128-
HuggingFaceChatCompletionSUT,
169+
HuggingFaceChatCompletionDedicatedSUT,
129170
"gemma-2-9b-it-hf",
130171
"gemma-2-9b-it-plf",
131172
HF_SECRET,
132173
)
133174

134175
SUTS.register(
135-
HuggingFaceChatCompletionSUT,
176+
HuggingFaceChatCompletionDedicatedSUT,
136177
"mistral-nemo-instruct-2407-hf",
137178
"mistral-nemo-instruct-2407-mgt",
138179
HF_SECRET,
139180
)
140181

141182
SUTS.register(
142-
HuggingFaceChatCompletionSUT,
183+
HuggingFaceChatCompletionDedicatedSUT,
143184
"nvidia-llama-3-1-nemotron-nano-8b-v1",
144185
"llama-3-1-nemotron-nano-8b-v-uhu",
145186
HF_SECRET,
146187
)
147188

148189

149190
SUTS.register(
150-
HuggingFaceChatCompletionSUT,
191+
HuggingFaceChatCompletionDedicatedSUT,
151192
"qwen2-5-7b-instruct-hf",
152193
"qwen2-5-7b-instruct-hgy",
153194
HF_SECRET,
154195
)
155196

156197
SUTS.register(
157-
HuggingFaceChatCompletionSUT,
198+
HuggingFaceChatCompletionDedicatedSUT,
158199
"olmo-2-0325-32b-instruct-hf",
159200
"olmo-2-0325-32b-instruct-yft",
160201
HF_SECRET,
161202
)
203+
204+
SUTS.register(
205+
HuggingFaceChatCompletionServerlessSUT,
206+
"google-gemma-3-27b-it-hf-nebius",
207+
"google/gemma-3-27b-it",
208+
"nebius",
209+
HF_SECRET,
210+
)

plugins/huggingface/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ packages = [{include = "modelgauge"}]
88

99
[tool.poetry.dependencies]
1010
python = "^3.10"
11-
huggingface-hub = "^0.26.3"
11+
huggingface-hub = "^0.29.0"
1212

1313
[build-system]
1414
requires = ["poetry-core"]

plugins/huggingface/tests/test_huggingface_chat_completion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
ChatMessage,
2121
HuggingFaceChatCompletionOutput,
2222
HuggingFaceChatCompletionRequest,
23-
HuggingFaceChatCompletionSUT,
23+
HuggingFaceChatCompletionDedicatedSUT,
2424
)
2525

2626

@@ -37,7 +37,7 @@ def mock_endpoint():
3737
def fake_sut(mock_get_inference_endpoint, mock_endpoint):
3838
mock_get_inference_endpoint.return_value = mock_endpoint
3939

40-
sut = HuggingFaceChatCompletionSUT("fake_uid", "fake_endpoint", HuggingFaceInferenceToken("fake_token"))
40+
sut = HuggingFaceChatCompletionDedicatedSUT("fake_uid", "fake_endpoint", HuggingFaceInferenceToken("fake_token"))
4141
return sut
4242

4343

0 commit comments

Comments
 (0)