Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 92 additions & 43 deletions plugins/huggingface/modelgauge/suts/huggingface_chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from dataclasses import asdict
from typing import Dict, List, Optional

Expand All @@ -22,6 +23,7 @@ class ChatMessage(BaseModel):


class HuggingFaceChatCompletionRequest(BaseModel):
model: Optional[str] = None
messages: List[ChatMessage]
logprobs: bool
top_logprobs: Optional[int] = None
Expand All @@ -39,18 +41,65 @@ class HuggingFaceChatCompletionOutput(BaseModel):
usage: Optional[Dict] = None


@modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
class HuggingFaceChatCompletionSUT(
PromptResponseSUT[HuggingFaceChatCompletionRequest, HuggingFaceChatCompletionOutput]
class BaseHuggingFaceChatCompletionSUT(
PromptResponseSUT[HuggingFaceChatCompletionRequest, HuggingFaceChatCompletionOutput], ABC
):
"""A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""
"""A Huggingface SUT that uses the chat_completion API."""

def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenceToken):
def __init__(self, uid: str, token: HuggingFaceInferenceToken):
super().__init__(uid)
self.token = token
self.inference_endpoint = inference_endpoint
self.client = None

@abstractmethod
def _create_client(self) -> InferenceClient:
"""Create the InferenceClient for the SUT. Must be implemented by subclasses."""
pass

def evaluate(self, request: HuggingFaceChatCompletionRequest) -> HuggingFaceChatCompletionOutput:
    """Send `request` through the chat_completion API and return a cacheable copy of the reply.

    The client is created on first use via `_create_client()` and reused on later calls.
    Request fields that are None are left out of the call.
    """
    if self.client is None:
        self.client = self._create_client()

    request_dict = request.model_dump(exclude_none=True)
    response = self.client.chat_completion(**request_dict)  # type: ignore
    # `usage` is Optional on the output model, so accept a response without usage data
    # instead of letting asdict(None) raise TypeError.
    usage = asdict(response.usage) if response.usage is not None else None
    # Convert to cacheable pydantic object.
    return HuggingFaceChatCompletionOutput(
        choices=[asdict(choice) for choice in response.choices],
        created=response.created,
        id=response.id,
        model=response.model,
        system_fingerprint=response.system_fingerprint,
        usage=usage,
    )

def translate_response(
    self, request: HuggingFaceChatCompletionRequest, response: HuggingFaceChatCompletionOutput
) -> SUTResponse:
    """Convert the raw chat-completion output into a SUTResponse.

    Exactly one choice is expected. Per-token top log-probabilities are attached only
    when the request asked for logprobs; otherwise `top_logprobs` is None.
    """
    assert len(response.choices) == 1, f"Expected a single response message, got {len(response.choices)}."
    choice = response.choices[0]
    text = choice["message"]["content"]
    assert text is not None

    # Nothing more to extract when log-probabilities were not requested.
    if not request.logprobs:
        return SUTResponse(text=text, top_logprobs=None)

    assert choice["logprobs"] is not None, "Expected logprobs, but not returned."
    # One TopTokens entry per generated token, each holding that token's candidates.
    per_token = [
        TopTokens(
            top_tokens=[
                TokenProbability(token=candidate["token"], logprob=candidate["logprob"])
                for candidate in entry["top_logprobs"]
            ]
        )
        for entry in choice["logprobs"]["content"]
    ]
    return SUTResponse(text=text, top_logprobs=per_token)


@modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
class HuggingFaceChatCompletionDedicatedSUT(BaseHuggingFaceChatCompletionSUT):
"""A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""

def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenceToken):
super().__init__(uid, token)
self.inference_endpoint = inference_endpoint

def _create_client(self):
endpoint = get_inference_endpoint(self.inference_endpoint, token=self.token.value)

Expand All @@ -74,7 +123,7 @@ def _create_client(self):
f"Endpoint is not running: Please contact admin to ensure endpoint is starting or running (status: {endpoint.status})"
)

self.client = InferenceClient(base_url=endpoint.url, token=self.token.value)
return InferenceClient(base_url=endpoint.url, token=self.token.value)

def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest:
logprobs = False
Expand All @@ -86,76 +135,76 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Hugg
**options.model_dump(),
)

def evaluate(self, request: HuggingFaceChatCompletionRequest) -> HuggingFaceChatCompletionOutput:
if self.client is None:
self._create_client()

request_dict = request.model_dump(exclude_none=True)
response = self.client.chat_completion(**request_dict) # type: ignore
# Convert to cacheable pydantic object.
return HuggingFaceChatCompletionOutput(
choices=[asdict(choice) for choice in response.choices],
created=response.created,
id=response.id,
model=response.model,
system_fingerprint=response.system_fingerprint,
usage=asdict(response.usage),
@modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT):
"""A SUT hosted by an inference provider on huggingface."""

def __init__(self, uid: str, model: str, provider: str, token: HuggingFaceInferenceToken):
    """Store the model and provider names; `uid` and `token` go to the base initializer."""
    super().__init__(uid, token)
    self.provider = provider
    self.model = model

def _create_client(self):
    """Return an InferenceClient for `self.provider`, using the token value as the API key."""
    client_kwargs = {"provider": self.provider, "api_key": self.token.value}
    return InferenceClient(**client_kwargs)

def translate_response(
self, request: HuggingFaceChatCompletionRequest, response: HuggingFaceChatCompletionOutput
) -> SUTResponse:
assert len(response.choices) == 1, f"Expected a single response message, got {len(response.choices)}."
choice = response.choices[0]
text = choice["message"]["content"]
assert text is not None
logprobs: Optional[List[TopTokens]] = None
if request.logprobs:
logprobs = []
assert choice["logprobs"] is not None, "Expected logprobs, but not returned."
lobprobs_sequence = choice["logprobs"]["content"]
for token in lobprobs_sequence:
top_tokens = []
for top_logprob in token["top_logprobs"]:
top_tokens.append(TokenProbability(token=top_logprob["token"], logprob=top_logprob["logprob"]))
logprobs.append(TopTokens(top_tokens=top_tokens))
return SUTResponse(text=text, top_logprobs=logprobs)
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest:
    """Build a chat request for `self.model` containing the prompt as a single user message.

    `logprobs` is enabled exactly when `options.top_logprobs` is set; every field of
    `options` is forwarded as a request field.
    """
    wants_logprobs = options.top_logprobs is not None
    user_message = ChatMessage(role="user", content=prompt.text)
    return HuggingFaceChatCompletionRequest(
        model=self.model,
        messages=[user_message],
        logprobs=wants_logprobs,
        **options.model_dump(),
    )


HF_SECRET = InjectSecret(HuggingFaceInferenceToken)

SUTS.register(
HuggingFaceChatCompletionSUT,
HuggingFaceChatCompletionDedicatedSUT,
"gemma-2-9b-it-hf",
"gemma-2-9b-it-plf",
HF_SECRET,
)

SUTS.register(
HuggingFaceChatCompletionSUT,
HuggingFaceChatCompletionDedicatedSUT,
"mistral-nemo-instruct-2407-hf",
"mistral-nemo-instruct-2407-mgt",
HF_SECRET,
)

SUTS.register(
HuggingFaceChatCompletionSUT,
HuggingFaceChatCompletionDedicatedSUT,
"nvidia-llama-3-1-nemotron-nano-8b-v1",
"llama-3-1-nemotron-nano-8b-v-uhu",
HF_SECRET,
)


SUTS.register(
HuggingFaceChatCompletionSUT,
HuggingFaceChatCompletionDedicatedSUT,
"qwen2-5-7b-instruct-hf",
"qwen2-5-7b-instruct-hgy",
HF_SECRET,
)

SUTS.register(
HuggingFaceChatCompletionSUT,
HuggingFaceChatCompletionDedicatedSUT,
"olmo-2-0325-32b-instruct-hf",
"olmo-2-0325-32b-instruct-yft",
HF_SECRET,
)

# Register a SUT backed by HuggingFaceChatCompletionServerlessSUT.
# NOTE(review): the arguments after the class presumably map onto its __init__
# parameters (uid, model, provider, token) — confirm against SUTS.register.
SUTS.register(
HuggingFaceChatCompletionServerlessSUT,
"google-gemma-3-27b-it-hf-nebius",
"google/gemma-3-27b-it",
"nebius",
HF_SECRET,
)
2 changes: 1 addition & 1 deletion plugins/huggingface/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ packages = [{include = "modelgauge"}]

[tool.poetry.dependencies]
python = "^3.10"
huggingface-hub = "^0.26.3"
huggingface-hub = "^0.29.0"

[build-system]
requires = ["poetry-core"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ChatMessage,
HuggingFaceChatCompletionOutput,
HuggingFaceChatCompletionRequest,
HuggingFaceChatCompletionSUT,
HuggingFaceChatCompletionDedicatedSUT,
)


Expand All @@ -37,7 +37,7 @@ def mock_endpoint():
def fake_sut(mock_get_inference_endpoint, mock_endpoint):
mock_get_inference_endpoint.return_value = mock_endpoint

sut = HuggingFaceChatCompletionSUT("fake_uid", "fake_endpoint", HuggingFaceInferenceToken("fake_token"))
sut = HuggingFaceChatCompletionDedicatedSUT("fake_uid", "fake_endpoint", HuggingFaceInferenceToken("fake_token"))
return sut


Expand Down
Loading