1+ from abc import ABC , abstractmethod
12from dataclasses import asdict
23from typing import Dict , List , Optional
34
@@ -22,6 +23,7 @@ class ChatMessage(BaseModel):
2223
2324
2425class HuggingFaceChatCompletionRequest (BaseModel ):
26+ model : Optional [str ] = None
2527 messages : List [ChatMessage ]
2628 logprobs : bool
2729 top_logprobs : Optional [int ] = None
@@ -39,18 +41,65 @@ class HuggingFaceChatCompletionOutput(BaseModel):
3941 usage : Optional [Dict ] = None
4042
4143
42- @modelgauge_sut (capabilities = [AcceptsTextPrompt , ProducesPerTokenLogProbabilities ])
43- class HuggingFaceChatCompletionSUT (
44- PromptResponseSUT [HuggingFaceChatCompletionRequest , HuggingFaceChatCompletionOutput ]
44+ class BaseHuggingFaceChatCompletionSUT (
45+ PromptResponseSUT [HuggingFaceChatCompletionRequest , HuggingFaceChatCompletionOutput ], ABC
4546):
46- """A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""
47+ """A Huggingface SUT that uses the chat_completion API."""
4748
def __init__(self, uid: str, token: HuggingFaceInferenceToken):
    """Record the API token and leave the inference client unset.

    Args:
        uid: Identifier forwarded to the parent class constructor.
        token: Hugging Face inference token kept on the instance as ``self.token``.
    """
    super().__init__(uid)
    # No client exists yet; evaluate() fills this in the first time it runs.
    self.client = None
    self.token = token
5353
@abstractmethod
def _create_client(self) -> InferenceClient:
    """Build and return the InferenceClient used to reach this SUT.

    Abstract hook: every concrete subclass has to provide its own version.
    The base implementation has no behavior of its own.
    """
58+
def evaluate(self, request: HuggingFaceChatCompletionRequest) -> HuggingFaceChatCompletionOutput:
    """Run one chat_completion call and return the reply as a pydantic object.

    The client is created on the first call (through ``_create_client``) and
    reused on later calls.

    Args:
        request: The chat completion request; fields that are ``None`` are
            left out of the keyword arguments sent to the client.

    Returns:
        A ``HuggingFaceChatCompletionOutput`` holding plain-dict copies of the
        response's choices and usage plus its scalar metadata fields.
    """
    if self.client is None:
        self.client = self._create_client()

    # exclude_none=True keeps unset optional fields out of the call entirely.
    request_dict = request.model_dump(exclude_none=True)
    response = self.client.chat_completion(**request_dict)  # type: ignore

    # The output model declares `usage` as optional, and asdict(None) raises
    # TypeError, so only convert usage when the response actually carries it.
    usage = asdict(response.usage) if response.usage is not None else None

    # Convert to cacheable pydantic object.
    return HuggingFaceChatCompletionOutput(
        choices=[asdict(choice) for choice in response.choices],
        created=response.created,
        id=response.id,
        model=response.model,
        system_fingerprint=response.system_fingerprint,
        usage=usage,
    )
74+
def translate_response(
    self, request: HuggingFaceChatCompletionRequest, response: HuggingFaceChatCompletionOutput
) -> SUTResponse:
    """Turn a stored chat completion output into a SUTResponse.

    Args:
        request: The request that produced ``response``; its ``logprobs`` flag
            decides whether per-token log probabilities are extracted.
        response: The stored completion output; must contain exactly one choice.

    Returns:
        A ``SUTResponse`` with the message text and, when log probabilities
        were requested, one ``TopTokens`` entry per generated token.

    Raises:
        AssertionError: If there is not exactly one choice, the message text is
            ``None``, or log probabilities were requested but are missing.
    """
    assert len(response.choices) == 1, f"Expected a single response message, got {len(response.choices)}."
    only_choice = response.choices[0]
    message_text = only_choice["message"]["content"]
    assert message_text is not None

    per_token: Optional[List[TopTokens]] = None
    if request.logprobs:
        assert only_choice["logprobs"] is not None, "Expected logprobs, but not returned."
        # One TopTokens per generated token, each listing that token's candidates.
        per_token = [
            TopTokens(
                top_tokens=[
                    TokenProbability(token=candidate["token"], logprob=candidate["logprob"])
                    for candidate in token_entry["top_logprobs"]
                ]
            )
            for token_entry in only_choice["logprobs"]["content"]
        ]
    return SUTResponse(text=message_text, top_logprobs=per_token)
93+
94+
95+ @modelgauge_sut (capabilities = [AcceptsTextPrompt , ProducesPerTokenLogProbabilities ])
96+ class HuggingFaceChatCompletionDedicatedSUT (BaseHuggingFaceChatCompletionSUT ):
97+ """A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""
98+
def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenceToken):
    """Set up the SUT and remember the name of its dedicated inference endpoint.

    Args:
        uid: Identifier forwarded to the base class constructor.
        inference_endpoint: Endpoint name later looked up by ``_create_client``.
        token: Hugging Face inference token forwarded to the base class.
    """
    super().__init__(uid=uid, token=token)
    self.inference_endpoint = inference_endpoint
102+
54103 def _create_client (self ):
55104 endpoint = get_inference_endpoint (self .inference_endpoint , token = self .token .value )
56105
@@ -74,7 +123,7 @@ def _create_client(self):
74123 f"Endpoint is not running: Please contact admin to ensure endpoint is starting or running (status: { endpoint .status } )"
75124 )
76125
77- self . client = InferenceClient (base_url = endpoint .url , token = self .token .value )
126+ return InferenceClient (base_url = endpoint .url , token = self .token .value )
78127
79128 def translate_text_prompt (self , prompt : TextPrompt , options : SUTOptions ) -> HuggingFaceChatCompletionRequest :
80129 logprobs = False
@@ -86,76 +135,76 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Hugg
86135 ** options .model_dump (),
87136 )
88137
89- def evaluate (self , request : HuggingFaceChatCompletionRequest ) -> HuggingFaceChatCompletionOutput :
90- if self .client is None :
91- self ._create_client ()
92138
93- request_dict = request .model_dump (exclude_none = True )
94- response = self .client .chat_completion (** request_dict ) # type: ignore
95- # Convert to cacheable pydantic object.
96- return HuggingFaceChatCompletionOutput (
97- choices = [asdict (choice ) for choice in response .choices ],
98- created = response .created ,
99- id = response .id ,
100- model = response .model ,
101- system_fingerprint = response .system_fingerprint ,
102- usage = asdict (response .usage ),
139+ @modelgauge_sut (capabilities = [AcceptsTextPrompt , ProducesPerTokenLogProbabilities ])
140+ class HuggingFaceChatCompletionServerlessSUT (BaseHuggingFaceChatCompletionSUT ):
141+ """A SUT hosted by an inference provider on huggingface."""
142+
def __init__(self, uid: str, model: str, provider: str, token: HuggingFaceInferenceToken):
    """Set up the SUT with the model and provider it should use.

    Args:
        uid: Identifier forwarded to the base class constructor.
        model: Model identifier placed on every request built by this SUT.
        provider: Provider name handed to the InferenceClient when it is created.
        token: Hugging Face inference token forwarded to the base class.
    """
    super().__init__(uid=uid, token=token)
    self.provider = provider
    self.model = model
147+
def _create_client(self) -> InferenceClient:
    """Return an InferenceClient configured with this SUT's provider and token.

    The return annotation mirrors the abstract declaration on the base class so
    both signatures read the same.

    Returns:
        A new ``InferenceClient`` built from ``self.provider`` and the token's value.
    """
    return InferenceClient(
        provider=self.provider,
        api_key=self.token.value,
    )
104153
105- def translate_response (
106- self , request : HuggingFaceChatCompletionRequest , response : HuggingFaceChatCompletionOutput
107- ) -> SUTResponse :
108- assert len (response .choices ) == 1 , f"Expected a single response message, got { len (response .choices )} ."
109- choice = response .choices [0 ]
110- text = choice ["message" ]["content" ]
111- assert text is not None
112- logprobs : Optional [List [TopTokens ]] = None
113- if request .logprobs :
114- logprobs = []
115- assert choice ["logprobs" ] is not None , "Expected logprobs, but not returned."
116- lobprobs_sequence = choice ["logprobs" ]["content" ]
117- for token in lobprobs_sequence :
118- top_tokens = []
119- for top_logprob in token ["top_logprobs" ]:
120- top_tokens .append (TokenProbability (token = top_logprob ["token" ], logprob = top_logprob ["logprob" ]))
121- logprobs .append (TopTokens (top_tokens = top_tokens ))
122- return SUTResponse (text = text , top_logprobs = logprobs )
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest:
    """Build a single-message chat completion request for this SUT's model.

    Args:
        prompt: Text prompt; its text becomes the content of one "user" message.
        options: SUT options; all of its dumped fields are passed through to the
            request, and log probabilities are requested whenever
            ``top_logprobs`` is set.

    Returns:
        The request object carrying the model, the message and the options.
    """
    # Log probabilities are requested exactly when a top_logprobs count was given.
    wants_logprobs = options.top_logprobs is not None
    user_message = ChatMessage(role="user", content=prompt.text)
    return HuggingFaceChatCompletionRequest(
        model=self.model,
        messages=[user_message],
        logprobs=wants_logprobs,
        **options.model_dump(),
    )
123164
124165
# Stand-in passed to SUTS.register wherever a constructor expects the token.
# NOTE(review): presumably resolved to a real HuggingFaceInferenceToken when the
# SUT is instantiated — confirm against InjectSecret.
HF_SECRET = InjectSecret(HuggingFaceInferenceToken)

# (SUT uid, dedicated inference endpoint name) pairs, registered in this order.
_DEDICATED_ENDPOINT_SUTS = (
    ("gemma-2-9b-it-hf", "gemma-2-9b-it-plf"),
    ("mistral-nemo-instruct-2407-hf", "mistral-nemo-instruct-2407-mgt"),
    ("nvidia-llama-3-1-nemotron-nano-8b-v1", "llama-3-1-nemotron-nano-8b-v-uhu"),
    ("qwen2-5-7b-instruct-hf", "qwen2-5-7b-instruct-hgy"),
    ("olmo-2-0325-32b-instruct-hf", "olmo-2-0325-32b-instruct-yft"),
)

for _sut_uid, _endpoint_name in _DEDICATED_ENDPOINT_SUTS:
    SUTS.register(
        HuggingFaceChatCompletionDedicatedSUT,
        _sut_uid,
        _endpoint_name,
        HF_SECRET,
    )

# Serverless SUT: arguments are uid, model, provider, then the token placeholder.
SUTS.register(
    HuggingFaceChatCompletionServerlessSUT,
    "google-gemma-3-27b-it-hf-nebius",
    "google/gemma-3-27b-it",
    "nebius",
    HF_SECRET,
)
0 commit comments