|
7 | 7 | from pydantic import BaseModel |
8 | 8 |
|
9 | 9 | from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken |
10 | | -from modelgauge.prompt import TextPrompt |
| 10 | +from modelgauge.prompt import TextPrompt, ChatPrompt |
11 | 11 | from modelgauge.secret_values import InjectSecret |
12 | 12 | from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse, TokenProbability, TopTokens |
13 | | -from modelgauge.sut_capabilities import AcceptsTextPrompt, ProducesPerTokenLogProbabilities |
| 13 | +from modelgauge.sut_capabilities import AcceptsTextPrompt, ProducesPerTokenLogProbabilities, AcceptsChatPrompt |
14 | 14 | from modelgauge.sut_decorator import modelgauge_sut |
15 | 15 | from modelgauge.sut_registry import SUTS |
16 | 16 |
|
@@ -92,7 +92,7 @@ def translate_response( |
92 | 92 | return SUTResponse(text=text, top_logprobs=logprobs) |
93 | 93 |
|
94 | 94 |
|
95 | | -@modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities]) |
| 95 | +@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt, ProducesPerTokenLogProbabilities]) |
96 | 96 | class HuggingFaceChatCompletionDedicatedSUT(BaseHuggingFaceChatCompletionSUT): |
97 | 97 | """A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API.""" |
98 | 98 |
|
@@ -135,6 +135,16 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Hugg |
135 | 135 | **options.model_dump(), |
136 | 136 | ) |
137 | 137 |
|
def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest:
    """Build a chat-completion request from a multi-turn chat prompt.

    Args:
        prompt: Conversation whose ``messages`` each expose ``role`` and ``text``.
        options: Generation options; every field returned by ``model_dump()``
            is forwarded to the request as a keyword argument.

    Returns:
        A request holding one ``ChatMessage`` per prompt message, in the
        original order.
    """
    # The request takes a flag rather than a count: switch log-probabilities
    # on whenever the caller set ``top_logprobs``, otherwise leave it unset.
    logprobs = True if options.top_logprobs is not None else None

    messages = []
    for message in prompt.messages:
        role = message.role.lower()
        # NOTE(review): presumably a model-authored turn lowercases to "sut";
        # the chat-completion API names that speaker "assistant". Confirm
        # against the prompt role enum. Other roles pass through unchanged.
        if role == "sut":
            role = "assistant"
        messages.append(ChatMessage(role=role, content=message.text))

    return HuggingFaceChatCompletionRequest(
        messages=messages,
        logprobs=logprobs,
        **options.model_dump(),
    )
| 147 | + |
138 | 148 |
|
139 | 149 | @modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities]) |
140 | 150 | class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT): |
@@ -181,6 +191,7 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Hugg |
181 | 191 | "llama-3-1-tulu-3-8b": "bzk", # check |
182 | 192 | "llama-3-1-tulu-3-70b": "ome", |
183 | 193 | "mistral-nemo-instruct-2407": "mgt", |
| 194 | + "mixtral-8x22b-instruct-v0-1": "kog", |
184 | 195 | "olmo-2-1124-13b-instruct": "ibo", |
185 | 196 | "olmo-2-0325-32b-instruct": "yft", |
186 | 197 | "qwen1-5-110b-chat": "gek", |
|
0 commit comments