From 2b0c1f0abc77b85e682fa49e72ad92416cb98679 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 12 Jan 2026 14:07:16 -0800 Subject: [PATCH 1/5] Rename SUTOptions to generalized ModelOptions --- src/modelgauge/base_test.py | 6 +- src/modelgauge/cli.py | 7 +- src/modelgauge/model_options.py | 68 ++++++++++++++++ src/modelgauge/pipeline_runner.py | 4 +- src/modelgauge/prompt_pipeline.py | 5 +- src/modelgauge/records.py | 4 +- src/modelgauge/sut.py | 80 ++----------------- src/modelgauge/suts/anthropic_api.py | 5 +- src/modelgauge/suts/aws_bedrock_client.py | 5 +- src/modelgauge/suts/azure_client.py | 5 +- src/modelgauge/suts/baseten_api.py | 7 +- src/modelgauge/suts/demo_01_yes_no_sut.py | 7 +- .../suts/demo_02_secrets_and_options_sut.py | 9 ++- src/modelgauge/suts/demo_03_sut_with_args.py | 7 +- src/modelgauge/suts/google_genai.py | 5 +- src/modelgauge/suts/huggingface_api.py | 5 +- .../suts/huggingface_chat_completion.py | 11 +-- src/modelgauge/suts/indirect_sut.py | 5 +- src/modelgauge/suts/meta_llama_client.py | 7 +- src/modelgauge/suts/mistral_sut.py | 5 +- src/modelgauge/suts/nvidia_nim_api_client.py | 9 ++- src/modelgauge/suts/openai_client.py | 8 +- src/modelgauge/suts/together_client.py | 15 ++-- src/modelgauge/suts/vertexai_mistral_sut.py | 5 +- src/modelgauge/tests/safe_v1.py | 4 +- src/modelgauge/tests/security.py | 4 +- .../modelbench_tests/test_benchmark_runner.py | 5 +- tests/modelgauge_tests/fake_sut.py | 7 +- .../sut_tests/test_anthropic_api.py | 9 ++- .../sut_tests/test_aws_bedrock_client.py | 7 +- .../sut_tests/test_baseten_api.py | 7 +- .../sut_tests/test_google_genai.py | 9 ++- .../sut_tests/test_huggingface_api.py | 5 +- .../test_huggingface_chat_completion.py | 5 +- .../sut_tests/test_indirect_sut.py | 4 +- .../sut_tests/test_meta_llama.py | 5 +- .../sut_tests/test_mistral_sut.py | 5 +- .../sut_tests/test_nvidia_nim_api_client.py | 5 +- .../sut_tests/test_openai_client.py | 9 ++- .../sut_tests/test_together_client.py | 27 ++++--- 
.../sut_tests/test_vertexai_mistral_sut.py | 5 +- tests/modelgauge_tests/test_cli.py | 5 +- tests/modelgauge_tests/test_modelship_sut.py | 4 +- .../modelgauge_tests/test_object_creation.py | 7 +- .../modelgauge_tests/test_pipeline_runner.py | 6 +- .../modelgauge_tests/test_prompt_pipeline.py | 5 +- tests/modelgauge_tests/test_records.py | 7 +- tests/modelgauge_tests/test_test_decorator.py | 6 +- 48 files changed, 248 insertions(+), 208 deletions(-) create mode 100644 src/modelgauge/model_options.py diff --git a/src/modelgauge/base_test.py b/src/modelgauge/base_test.py index 7b9e99e1b..5f5c896c3 100644 --- a/src/modelgauge/base_test.py +++ b/src/modelgauge/base_test.py @@ -7,7 +7,7 @@ SUTResponseAnnotations, TestItem, ) -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import SUTCapability from modelgauge.tracked_object import TrackedObject from modelgauge.typed_data import Typeable, TypedData @@ -27,7 +27,7 @@ class attribute `requires_sut_capabilities` as well as `initialization_record` o initialization_record: Initialization data that can be used to reconstruct a test instance. """ - _sut_options = SUTOptions() + _sut_options = ModelOptions() # Set automatically by @modelgauge_test() requires_sut_capabilities: Sequence[Type[SUTCapability]] @@ -38,7 +38,7 @@ def __init__(self, uid: str): self.initialization_record: InitializationRecord @classmethod - def sut_options(cls) -> SUTOptions: + def sut_options(cls) -> ModelOptions: """Returns the SUT options that are supplied in each test item. 
Concrete subclasses can override this method to specify their own SUT options.""" return cls._sut_options diff --git a/src/modelgauge/cli.py b/src/modelgauge/cli.py index 53a844d39..a473c34b3 100644 --- a/src/modelgauge/cli.py +++ b/src/modelgauge/cli.py @@ -30,7 +30,8 @@ from modelgauge.secret_values import get_all_secrets, RawSecrets from modelgauge.simple_test_runner import run_prompt_response_test from modelgauge.single_turn_prompt_response import SUTResponse, TestItem -from modelgauge.sut import PromptResponseSUT, SUTOptions +from modelgauge.sut import PromptResponseSUT +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_registry import SUTS from modelgauge.test_registry import TESTS @@ -150,7 +151,7 @@ def run_sut( """Send a prompt from the command line to a SUT.""" # TODO Consider a SUT factory that takes in a SUTDefinition and returns a SUT - options = SUTOptions.create_from_arguments(max_tokens, temp, top_p, top_k, top_logprobs) + options = ModelOptions.create_from_arguments(max_tokens, temp, top_p, top_k, top_logprobs) # Current this only knows how to do prompt response, so assert that is what we have. 
sut_instance = make_sut(sut) @@ -344,7 +345,7 @@ def run_job( # make sure the job has everything it needs to run secrets = load_secrets_from_config() if sut: - sut_options = SUTOptions.create_from_arguments(max_tokens, temp, top_p, top_k) + sut_options = ModelOptions.create_from_arguments(max_tokens, temp, top_p, top_k) sut_instance = make_sut(sut) if AcceptsTextPrompt not in sut_instance.capabilities: raise click.BadParameter(f"{sut} does not accept text prompts") diff --git a/src/modelgauge/model_options.py b/src/modelgauge/model_options.py new file mode 100644 index 000000000..7d10ead18 --- /dev/null +++ b/src/modelgauge/model_options.py @@ -0,0 +1,68 @@ +from typing import Optional, List + +from pydantic import BaseModel, model_validator + + +class ModelOptions(BaseModel): + """ + An exhaustive set of options that could potentially be desired by a model. + + Not all SUTs and annotators respect all options. + """ + + max_tokens: int = 100 + """Maximum number of tokens to generate (per completion)""" + + max_total_output_tokens: Optional[int] = None + """Maximum number of tokens for all generated SUT outputs, including reasoning.""" + + temperature: Optional[float] = None + """Temperature parameter that governs diversity""" + + top_k_per_token: Optional[int] = None + """Take this many highest probability candidates per token in the completion""" + + stop_sequences: Optional[List[str]] = None + """Stop generating once we hit one of these strings.""" + + top_p: Optional[float] = None + """Same from tokens that occupy this probability mass (nucleus sampling)""" + + presence_penalty: Optional[float] = None + """Penalize repetition (OpenAI & Writer only)""" + + frequency_penalty: Optional[float] = None + """Penalize repetition (OpenAI & Writer only)""" + + random: Optional[str] = None + """Used to control randomness. 
Expect different responses for the same + request but with different values for `random`.""" + + # Must specify SUTCapabilities for these + top_logprobs: Optional[int] = None + """If present, will request the log probabilities for this + many of the top tokens at each token position.""" + + @model_validator(mode="after") + def check_max_total_output_tokens(self): + if self.max_total_output_tokens is not None and self.max_total_output_tokens < self.max_tokens: + raise ValueError( + f"Invalid ModelOptions. max_total_output_tokens ({self.max_total_output_tokens}) must be >= max_tokens ({self.max_tokens})." + ) + return self + + @staticmethod + def create_from_arguments(max_tokens=None, temp=None, top_p=None, top_k=None, top_logprobs=None): + options = ModelOptions() + if max_tokens is not None: + options.max_tokens = max_tokens + if temp is not None: + options.temperature = temp + if top_p is not None: + options.top_p = top_p + if top_k is not None: + options.top_k_per_token = top_k + if top_logprobs is not None: + options.top_logprobs = top_logprobs + + return options diff --git a/src/modelgauge/pipeline_runner.py b/src/modelgauge/pipeline_runner.py index 7b2a92f64..1fbd01267 100644 --- a/src/modelgauge/pipeline_runner.py +++ b/src/modelgauge/pipeline_runner.py @@ -14,7 +14,7 @@ from modelgauge.pipeline import Pipeline from modelgauge.prompt_pipeline import PromptSink, PromptSource, PromptSutAssigner, PromptSutWorkers from modelgauge.ready import ReadyResponses, Readyable -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions logger = get_logger(__name__) @@ -26,7 +26,7 @@ def __init__( input_dataset, output_dir, cache_dir=None, - sut_options=SUTOptions(), + sut_options=ModelOptions(), tag=None, ): self.num_workers = num_workers diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py index 42eb1a19f..dd5df79ac 100644 --- a/src/modelgauge/prompt_pipeline.py +++ b/src/modelgauge/prompt_pipeline.py @@ -6,7 
+6,8 @@ from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem -from modelgauge.sut import PromptResponseSUT, SUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUT, SUTResponse +from modelgauge.model_options import ModelOptions logger = get_logger(__name__) @@ -31,7 +32,7 @@ def handle_item(self, item): class PromptSutWorkers(CachingPipe): - def __init__(self, suts: dict[str, SUT], sut_options: Optional[SUTOptions] = None, workers=None, cache_path=None): + def __init__(self, suts: dict[str, SUT], sut_options: Optional[ModelOptions] = None, workers=None, cache_path=None): self.sleep_time = 10 if workers is None: workers = 8 diff --git a/src/modelgauge/records.py b/src/modelgauge/records.py index b315baf50..321d70eff 100644 --- a/src/modelgauge/records.py +++ b/src/modelgauge/records.py @@ -5,7 +5,7 @@ SUTResponseAnnotations, TestItem, ) -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions from pydantic import AwareDatetime, BaseModel, Field from typing import Dict, List, Mapping @@ -37,7 +37,7 @@ class TestRecord(BaseModel): run_timestamp: AwareDatetime = Field(default_factory=current_local_datetime) test_uid: str test_initialization: InitializationRecord - sut_options: SUTOptions + sut_options: ModelOptions dependency_versions: Mapping[str, str] sut_uid: str sut_initialization: InitializationRecord diff --git a/src/modelgauge/sut.py b/src/modelgauge/sut.py index ca104e928..5531c7ad8 100644 --- a/src/modelgauge/sut.py +++ b/src/modelgauge/sut.py @@ -1,8 +1,9 @@ from abc import abstractmethod -from typing import List, Optional, Sequence, Type +from typing import Optional, Sequence, Type -from pydantic import BaseModel, model_validator +from pydantic import BaseModel +from modelgauge.model_options import ModelOptions from modelgauge.not_implemented import not_implemented from 
modelgauge.prompt import ChatPrompt, TextPrompt from modelgauge.ready import Readyable, ReadyResponse @@ -13,71 +14,6 @@ REFUSAL_RESPONSE = "" -class SUTOptions(BaseModel): - """ - An exhaustive set of options that could potentially be desired by a SUT. - - Not all SUTs respect all options. - """ - - max_tokens: int = 100 - """Maximum number of tokens to generate (per completion)""" - - max_total_output_tokens: Optional[int] = None - """Maximum number of tokens for all generated SUT outputs, including reasoning.""" - - temperature: Optional[float] = None - """Temperature parameter that governs diversity""" - - top_k_per_token: Optional[int] = None - """Take this many highest probability candidates per token in the completion""" - - stop_sequences: Optional[List[str]] = None - """Stop generating once we hit one of these strings.""" - - top_p: Optional[float] = None - """Same from tokens that occupy this probability mass (nucleus sampling)""" - - presence_penalty: Optional[float] = None - """Penalize repetition (OpenAI & Writer only)""" - - frequency_penalty: Optional[float] = None - """Penalize repetition (OpenAI & Writer only)""" - - random: Optional[str] = None - """Used to control randomness. Expect different responses for the same - request but with different values for `random`.""" - - # Must specify SUTCapabilities for these - top_logprobs: Optional[int] = None - """If present, will request the log probabilities for this - many of the top tokens at each token position.""" - - @model_validator(mode="after") - def check_max_total_output_tokens(self): - if self.max_total_output_tokens is not None and self.max_total_output_tokens < self.max_tokens: - raise ValueError( - f"Invalid SUTOptions. max_total_output_tokens ({self.max_total_output_tokens}) must be >= max_tokens ({self.max_tokens})." 
- ) - return self - - @staticmethod - def create_from_arguments(max_tokens=None, temp=None, top_p=None, top_k=None, top_logprobs=None): - options = SUTOptions() - if max_tokens is not None: - options.max_tokens = max_tokens - if temp is not None: - options.temperature = temp - if top_p is not None: - options.top_p = top_p - if top_k is not None: - options.top_k_per_token = top_k - if top_logprobs is not None: - options.top_logprobs = top_logprobs - - return options - - class TokenProbability(BaseModel): """Probability assigned to a given token.""" @@ -98,10 +34,10 @@ class SUTResponse(BaseModel): top_logprobs: Optional[Sequence[TopTokens]] = None """For each position, list the probabilities for each of the most likely tokens. - To guarantee this field is not None, the Test must specify SUTOptions.top_logprobs + To guarantee this field is not None, the Test must specify ModelOptions.top_logprobs and that it requires_sut_capabilities ProducesPerTokenLogProbabilities. SUTs that set this value must specify they have the ProducesPerTokenLogProbabilities - capability. They may conditional setting the field on on SUTOptions.top_logprobs being not None. + capability. They may conditional setting the field on on ModelOptions.top_logprobs being not None. """ @@ -127,7 +63,7 @@ def __init__(self, uid: str): _READINESS_CHECK_TEXT_PROMPT = TextPrompt(text="Why did the chicken cross the road?") -_READINESS_CHECK_SUT_OPTIONS = SUTOptions(max_tokens=20) +_READINESS_CHECK_SUT_OPTIONS = ModelOptions(max_tokens=20) class PromptResponseSUT(SUT, Readyable): @@ -142,7 +78,7 @@ def run_readiness_check(self) -> ReadyResponse: return ReadyResponse(is_ready=response.text is not None, response=response) @not_implemented - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions): + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions): """Convert the prompt + SUT options into the SUT's native representation. 
This method must be implemented if the SUT accepts text prompts. @@ -150,7 +86,7 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions): raise NotImplementedError(f"SUT {self.__class__.__name__} does not implement translate_text_prompt.") @not_implemented - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions): + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions): """Convert the prompt + SUT options into the SUT's native representation. This method must be implemented if the SUT accepts chat prompts. diff --git a/src/modelgauge/suts/anthropic_api.py b/src/modelgauge/suts/anthropic_api.py index 2ba785174..67357fa2d 100644 --- a/src/modelgauge/suts/anthropic_api.py +++ b/src/modelgauge/suts/anthropic_api.py @@ -11,7 +11,8 @@ from modelgauge.general import APIException from modelgauge.prompt import ChatRole, TextPrompt from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -54,7 +55,7 @@ def _load_client(self) -> Anthropic: max_retries=7, ) - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> AnthropicRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> AnthropicRequest: messages = [OpenAIChatMessage(content=prompt.text, role=_ROLE_MAP[ChatRole.user])] return AnthropicRequest( model=self.model, diff --git a/src/modelgauge/suts/aws_bedrock_client.py b/src/modelgauge/suts/aws_bedrock_client.py index e16cbca32..8af838b9a 100644 --- a/src/modelgauge/suts/aws_bedrock_client.py +++ b/src/modelgauge/suts/aws_bedrock_client.py @@ -9,7 +9,8 @@ from modelgauge.prompt import TextPrompt from 
modelgauge.retry_decorator import retry from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -122,7 +123,7 @@ def _load_client(self): aws_secret_access_key=self.secret_access_key, ) - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> BedrockRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> BedrockRequest: inference_config = BedrockRequest.InferenceConfig( maxTokens=options.max_tokens, temperature=options.temperature, diff --git a/src/modelgauge/suts/azure_client.py b/src/modelgauge/suts/azure_client.py index b0216dfa1..040a1c573 100644 --- a/src/modelgauge/suts/azure_client.py +++ b/src/modelgauge/suts/azure_client.py @@ -8,7 +8,8 @@ from modelgauge.general import APIException from modelgauge.prompt import TextPrompt from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -105,7 +106,7 @@ def __init__(self, uid: str, endpoint_url: str, api_key: AzureApiKey): self.endpoint_url = endpoint_url self.api_key = api_key.value - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> AzureChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> AzureChatRequest: messages = [AzureChatRequest.Message(content=prompt.text, role="user")] 
return AzureChatRequest( messages=messages, diff --git a/src/modelgauge/suts/baseten_api.py b/src/modelgauge/suts/baseten_api.py index ecd938250..9064105f7 100644 --- a/src/modelgauge/suts/baseten_api.py +++ b/src/modelgauge/suts/baseten_api.py @@ -6,7 +6,8 @@ from modelgauge.prompt import TextPrompt from modelgauge.retry_decorator import retry from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -94,7 +95,7 @@ def translate_response(self, request: BasetenChatRequest, response: BasetenRespo @modelgauge_sut(capabilities=[AcceptsTextPrompt]) class BasetenPromptSUT(BasetenSUT): - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> BasetenChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> BasetenChatRequest: return BasetenChatPromptRequest( model=self.model, prompt=prompt.text, stream=False, max_tokens=options.max_tokens ) @@ -102,7 +103,7 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Base @modelgauge_sut(capabilities=[AcceptsTextPrompt]) class BasetenMessagesSUT(BasetenSUT): - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> BasetenChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> BasetenChatRequest: return BasetenChatMessagesRequest( model=self.model, messages=[BasetenChatMessage(role="user", content=prompt.text)], diff --git a/src/modelgauge/suts/demo_01_yes_no_sut.py b/src/modelgauge/suts/demo_01_yes_no_sut.py index 1587ae1c6..b884f433b 100644 --- a/src/modelgauge/suts/demo_01_yes_no_sut.py +++ b/src/modelgauge/suts/demo_01_yes_no_sut.py 
@@ -1,6 +1,7 @@ from modelgauge.prompt import ChatPrompt, TextPrompt from modelgauge.prompt_formatting import format_chat -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -24,10 +25,10 @@ class DemoYesNoResponse(BaseModel): class DemoYesNoSUT(PromptResponseSUT): """This SUT demonstrates the bare minimum behavior of a SUT: Use the input Prompt to determine the response.""" - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> DemoYesNoRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> DemoYesNoRequest: return DemoYesNoRequest(text=prompt.text) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> DemoYesNoRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> DemoYesNoRequest: return DemoYesNoRequest(text=format_chat(prompt)) def evaluate(self, request: DemoYesNoRequest) -> DemoYesNoResponse: diff --git a/src/modelgauge/suts/demo_02_secrets_and_options_sut.py b/src/modelgauge/suts/demo_02_secrets_and_options_sut.py index a373387b9..405195623 100644 --- a/src/modelgauge/suts/demo_02_secrets_and_options_sut.py +++ b/src/modelgauge/suts/demo_02_secrets_and_options_sut.py @@ -1,7 +1,8 @@ import random from modelgauge.prompt import ChatPrompt, TextPrompt from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from 
modelgauge.sut_registry import SUTS @@ -46,14 +47,14 @@ def __init__(self, uid: str, api_key: DemoApiKey): def _load_client(self) -> "RandomWordsClient": return RandomWordsClient(api_key=self.api_key) - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> DemoRandomWordsRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> DemoRandomWordsRequest: return self._translate(prompt.text, options) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> DemoRandomWordsRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> DemoRandomWordsRequest: # All we care about are the words in the chat history, not who said them. return self._translate(_words_in_chat(prompt), options) - def _translate(self, text, options: SUTOptions) -> DemoRandomWordsRequest: + def _translate(self, text, options: ModelOptions) -> DemoRandomWordsRequest: return DemoRandomWordsRequest( source_text=text, # Copy over the requested options. 
diff --git a/src/modelgauge/suts/demo_03_sut_with_args.py b/src/modelgauge/suts/demo_03_sut_with_args.py index fc3742742..61bec03af 100644 --- a/src/modelgauge/suts/demo_03_sut_with_args.py +++ b/src/modelgauge/suts/demo_03_sut_with_args.py @@ -1,5 +1,6 @@ from modelgauge.prompt import ChatPrompt, TextPrompt -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -26,10 +27,10 @@ def __init__(self, uid: str, response_text: str): super().__init__(uid) self.response_text = response_text - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> DemoConstantRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> DemoConstantRequest: return DemoConstantRequest(configured_response=self.response_text) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> DemoConstantRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> DemoConstantRequest: return DemoConstantRequest(configured_response=self.response_text) def evaluate(self, request: DemoConstantRequest) -> DemoConstantResponse: diff --git a/src/modelgauge/suts/google_genai.py b/src/modelgauge/suts/google_genai.py index 39a0dce04..d97177264 100644 --- a/src/modelgauge/suts/google_genai.py +++ b/src/modelgauge/suts/google_genai.py @@ -19,7 +19,8 @@ from modelgauge.prompt import TextPrompt from modelgauge.retry_decorator import retry from modelgauge.secret_values import InjectSecret, loggable_secret, RequiredSecret, SecretDescription -from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTOptions, SUTResponse # usort: skip +from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTResponse # usort: 
skip +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -70,7 +71,7 @@ def _load_client(self) -> genai.Client: logger.exception(f"Failed to load genai.Client with api_key='{loggable_secret(self.api_key)}'") raise - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> GenAiRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> GenAiRequest: optional = {} if not self.reasoning: optional["thinking_config"] = ThinkingConfig( diff --git a/src/modelgauge/suts/huggingface_api.py b/src/modelgauge/suts/huggingface_api.py index 254cc1e0c..e35c6cfa4 100644 --- a/src/modelgauge/suts/huggingface_api.py +++ b/src/modelgauge/suts/huggingface_api.py @@ -5,7 +5,8 @@ from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken from modelgauge.prompt import TextPrompt from modelgauge.secret_values import InjectSecret -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -36,7 +37,7 @@ def __init__(self, uid: str, api_url: str, token: HuggingFaceInferenceToken): self.token = token.value self.api_url = api_url - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> HuggingFaceChatRequest: return HuggingFaceChatRequest( inputs=prompt.text, parameters=HuggingFaceChatParams(max_new_tokens=options.max_tokens, temperature=options.temperature), diff --git a/src/modelgauge/suts/huggingface_chat_completion.py b/src/modelgauge/suts/huggingface_chat_completion.py index 
2fb359891..1e31e0508 100644 --- a/src/modelgauge/suts/huggingface_chat_completion.py +++ b/src/modelgauge/suts/huggingface_chat_completion.py @@ -12,7 +12,8 @@ from modelgauge.prompt import TextPrompt, ChatPrompt from modelgauge.retry_decorator import retry from modelgauge.secret_values import InjectSecret -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse, TokenProbability, TopTokens +from modelgauge.sut import PromptResponseSUT, SUTResponse, TokenProbability, TopTokens +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -161,7 +162,7 @@ def _optional_request_kwargs(self) -> Dict: optional_kwargs["model"] = self.model return optional_kwargs - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> HuggingFaceChatCompletionRequest: logprobs = None if options.top_logprobs is not None: logprobs = True @@ -172,7 +173,7 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Hugg **options.model_dump(), ) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> HuggingFaceChatCompletionRequest: logprobs = None if options.top_logprobs is not None: logprobs = True @@ -199,7 +200,7 @@ def _create_client(self): api_key=self.token.value, ) - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> HuggingFaceChatCompletionRequest: logprobs = None if options.top_logprobs is not None: logprobs = True @@ -210,7 +211,7 @@ def 
translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Hugg **options.model_dump(), ) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> HuggingFaceChatCompletionRequest: messages = [] for message in prompt.messages: messages.append(ChatMessage(content=message.text, role=message.role.lower())) diff --git a/src/modelgauge/suts/indirect_sut.py b/src/modelgauge/suts/indirect_sut.py index e4153f233..e27408ff1 100644 --- a/src/modelgauge/suts/indirect_sut.py +++ b/src/modelgauge/suts/indirect_sut.py @@ -9,7 +9,8 @@ from modelgauge.prompt import TextPrompt from modelgauge.ready import ReadyResponse from modelgauge.secret_values import InjectSecret -from modelgauge.sut import PromptResponseSUT, SUTResponse, SUTOptions +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_definition import SUTDefinition @@ -52,7 +53,7 @@ def __init__(self, uid: str, model_name: str, port: int = DEFAULT_PORT): def is_ready(self) -> ReadyResponse: return ReadyResponse(True) - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> IndirectSUTRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> IndirectSUTRequest: messages = [OpenAIChatMessage(content=prompt.text, role=USER_ROLE)] return IndirectSUTRequest( request_id=self._id_generator.next(), diff --git a/src/modelgauge/suts/meta_llama_client.py b/src/modelgauge/suts/meta_llama_client.py index 559e2de77..b83221e93 100644 --- a/src/modelgauge/suts/meta_llama_client.py +++ b/src/modelgauge/suts/meta_llama_client.py @@ -8,7 +8,8 @@ from modelgauge.prompt import TextPrompt from modelgauge.retry_decorator import retry from modelgauge.secret_values 
import InjectSecret, RequiredSecret, SecretDescription -from modelgauge.sut import PromptResponseSUT, REFUSAL_RESPONSE, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, REFUSAL_RESPONSE, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -46,7 +47,7 @@ def __init__(self, uid: str, model: str, api_key: MetaLlamaApiKey): self.model = model self.client = LlamaAPIClient(api_key=api_key.value, max_retries=10, timeout=Timeout(120)) - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> MetaLlamaChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> MetaLlamaChatRequest: return MetaLlamaChatRequest( model=self.model, messages=[InputMessage(role="user", content=prompt.text)], @@ -82,7 +83,7 @@ def __init__(self, uid: str, model: str, api_key: MetaLlamaApiKey): self.model = model self.client = LlamaAPIClient(api_key=api_key.value, max_retries=10, timeout=Timeout(120)) - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> MetaLlamaChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> MetaLlamaChatRequest: return MetaLlamaChatRequest( model=self.model, messages=[InputMessage(role="user", content=prompt.text)], diff --git a/src/modelgauge/suts/mistral_sut.py b/src/modelgauge/suts/mistral_sut.py index 69c8eba9e..0ebd36b92 100644 --- a/src/modelgauge/suts/mistral_sut.py +++ b/src/modelgauge/suts/mistral_sut.py @@ -6,7 +6,8 @@ from modelgauge.prompt import TextPrompt from modelgauge.retry_decorator import retry from modelgauge.secret_values import InjectSecret -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities 
import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -52,7 +53,7 @@ def client(self): self._client = MistralAIClient(self.model_name, self._api_key) return self._client - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> MistralAIRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> MistralAIRequest: args = {"model": self.model_name, "messages": [{"role": _USER_ROLE, "content": prompt.text}]} if options.temperature is not None: args["temperature"] = options.temperature diff --git a/src/modelgauge/suts/nvidia_nim_api_client.py b/src/modelgauge/suts/nvidia_nim_api_client.py index cd9a6276a..f12702f12 100644 --- a/src/modelgauge/suts/nvidia_nim_api_client.py +++ b/src/modelgauge/suts/nvidia_nim_api_client.py @@ -12,7 +12,8 @@ RequiredSecret, SecretDescription, ) -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import ( AcceptsChatPrompt, AcceptsTextPrompt, @@ -88,17 +89,17 @@ def __init__(self, uid: str, model: str, api_key: NvidiaNIMApiKey): def _load_client(self) -> OpenAI: return OpenAI(api_key=self.api_key, base_url="https://integrate.api.nvidia.com/v1") - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> OpenAIChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> OpenAIChatRequest: messages = [OpenAIChatMessage(content=prompt.text, role=_USER_ROLE)] return self._translate_request(messages, options) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> OpenAIChatRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> OpenAIChatRequest: messages = [] for message in prompt.messages: messages.append(OpenAIChatMessage(content=message.text, 
role=_ROLE_MAP[message.role])) return self._translate_request(messages, options) - def _translate_request(self, messages: List[OpenAIChatMessage], options: SUTOptions): + def _translate_request(self, messages: List[OpenAIChatMessage], options: ModelOptions): optional_kwargs: Dict[str, Any] = {} return OpenAIChatRequest( messages=messages, diff --git a/src/modelgauge/suts/openai_client.py b/src/modelgauge/suts/openai_client.py index e5216bc34..f373a48a3 100644 --- a/src/modelgauge/suts/openai_client.py +++ b/src/modelgauge/suts/openai_client.py @@ -17,11 +17,11 @@ from modelgauge.secret_values import InjectSecret from modelgauge.sut import ( PromptResponseSUT, - SUTOptions, SUTResponse, TokenProbability, TopTokens, ) +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import ( AcceptsChatPrompt, AcceptsTextPrompt, @@ -119,17 +119,17 @@ def _load_client(self) -> OpenAI | None: else: return OpenAI(api_key=self.api_key, max_retries=7) - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> OpenAIChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> OpenAIChatRequest: messages = [OpenAIChatMessage(content=prompt.text, role=_USER_ROLE)] return self._translate_request(messages, options) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> OpenAIChatRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> OpenAIChatRequest: messages = [] for message in prompt.messages: messages.append(OpenAIChatMessage(content=message.text, role=_ROLE_MAP[message.role])) return self._translate_request(messages, options) - def _translate_request(self, messages: List[OpenAIChatMessage], options: SUTOptions): + def _translate_request(self, messages: List[OpenAIChatMessage], options: ModelOptions): optional_kwargs: Dict[str, Any] = {} if options.top_logprobs is not None: optional_kwargs["logprobs"] = True diff --git 
a/src/modelgauge/suts/together_client.py b/src/modelgauge/suts/together_client.py index 3d7f386ce..11838c178 100644 --- a/src/modelgauge/suts/together_client.py +++ b/src/modelgauge/suts/together_client.py @@ -12,7 +12,8 @@ from modelgauge.prompt_formatting import format_chat from modelgauge.tokenizer import GeneralTokenizer from modelgauge.secret_values import InjectSecret -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse, TokenProbability, TopTokens +from modelgauge.sut import PromptResponseSUT, SUTResponse, TokenProbability, TopTokens +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -120,10 +121,10 @@ def __init__(self, uid: str, model, api_key: TogetherApiKey): self.model = model self.api_key = api_key.value - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> TogetherCompletionsRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> TogetherCompletionsRequest: return self._translate_request(prompt.text, options) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> TogetherCompletionsRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> TogetherCompletionsRequest: return self._translate_request(format_chat(prompt, user_role=_USER_ROLE, sut_role=_ASSISTANT_ROLE), options) def _translate_request(self, text, options): @@ -220,16 +221,16 @@ def __init__(self, uid: str, model, api_key: TogetherApiKey): self.model = model self.api_key = api_key.value - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> TogetherChatRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> TogetherChatRequest: return self._translate_request([TogetherChatRequest.Message(content=prompt.text, 
role=_USER_ROLE)], options) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> TogetherChatRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> TogetherChatRequest: messages = [] for message in prompt.messages: messages.append(TogetherChatRequest.Message(content=message.text, role=_ROLE_MAP[message.role])) return self._translate_request(messages, options) - def _translate_request(self, messages: List[TogetherChatRequest.Message], options: SUTOptions): + def _translate_request(self, messages: List[TogetherChatRequest.Message], options: ModelOptions): return TogetherChatRequest( model=self.model, messages=messages, @@ -281,7 +282,7 @@ def __init__(self, uid: str, model, api_key: TogetherApiKey): self.tokenizer = GeneralTokenizer() def _translate_request( - self, messages: List[TogetherChatRequest.Message], options: SUTOptions + self, messages: List[TogetherChatRequest.Message], options: ModelOptions ) -> TogetherThinkingChatRequest: max_tokens = options.max_total_output_tokens if max_tokens is None: diff --git a/src/modelgauge/suts/vertexai_mistral_sut.py b/src/modelgauge/suts/vertexai_mistral_sut.py index 79590d426..a76d1a1fa 100644 --- a/src/modelgauge/suts/vertexai_mistral_sut.py +++ b/src/modelgauge/suts/vertexai_mistral_sut.py @@ -2,7 +2,8 @@ from modelgauge.prompt import TextPrompt from modelgauge.secret_values import InjectSecret -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -73,7 +74,7 @@ def client(self) -> VertexAIClient: ) return self._client - def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> VertexAIMistralRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: 
ModelOptions) -> VertexAIMistralRequest: args = { "model": f"{self.model_name}-{self.model_version}", "messages": [{"role": _USER_ROLE, "content": prompt.text}], diff --git a/src/modelgauge/tests/safe_v1.py b/src/modelgauge/tests/safe_v1.py index 449ed70ad..e6366a2ed 100644 --- a/src/modelgauge/tests/safe_v1.py +++ b/src/modelgauge/tests/safe_v1.py @@ -25,7 +25,7 @@ TestItem, convert_annotation_to_measurement, ) -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.test_decorator import modelgauge_test from modelgauge.test_registry import TESTS @@ -99,7 +99,7 @@ class BaseSafeTestVersion1(PromptResponseTest, ABC): """ hazards = Hazards() - _sut_options = SUTOptions( + _sut_options = ModelOptions( max_tokens=3000, max_total_output_tokens=10000, # For reasoning SUTs. temperature=0.01, diff --git a/src/modelgauge/tests/security.py b/src/modelgauge/tests/security.py index d6e57a8ef..509518e8f 100644 --- a/src/modelgauge/tests/security.py +++ b/src/modelgauge/tests/security.py @@ -27,7 +27,7 @@ TestItem, convert_annotation_to_measurement, ) -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.tests.safe_v1 import Hazards from modelgauge.test_decorator import modelgauge_test @@ -43,7 +43,7 @@ class BaseSecurityTest(PromptResponseTest, ABC): hazards = Hazards() prompt_sets: dict persona: str - _sut_options = SUTOptions( + _sut_options = ModelOptions( max_tokens=3000, max_total_output_tokens=10000, # For reasoning SUTs. 
temperature=0.01, diff --git a/tests/modelbench_tests/test_benchmark_runner.py b/tests/modelbench_tests/test_benchmark_runner.py index 3b6f53049..6ddff0afa 100644 --- a/tests/modelbench_tests/test_benchmark_runner.py +++ b/tests/modelbench_tests/test_benchmark_runner.py @@ -16,7 +16,8 @@ from modelgauge.prompt import TextPrompt from modelgauge.secret_values import get_all_secrets, RawSecrets from modelgauge.single_turn_prompt_response import TestItem -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_registry import SUTS from modelgauge.suts.demo_01_yes_no_sut import DemoYesNoResponse from modelgauge_tests.fake_annotator import ( @@ -531,7 +532,7 @@ def test_benchmark_sut_worker_cached(self, item_from_test, a_wrapped_test, tmp_p run = self.a_run(tmp_path, suts=[a_sut]) cache = InMemoryCache() bsw = TestRunSutWorker(run, cache) - request = a_sut.translate_text_prompt(item_from_test.prompt, SUTOptions()) + request = a_sut.translate_text_prompt(item_from_test.prompt, ModelOptions()) key = bsw.make_cache_key(request, "demo_yes_no") cache[key] = DemoYesNoResponse(number_of_words=1, text="No") diff --git a/tests/modelgauge_tests/fake_sut.py b/tests/modelgauge_tests/fake_sut.py index 020fdca1d..0c924854b 100644 --- a/tests/modelgauge_tests/fake_sut.py +++ b/tests/modelgauge_tests/fake_sut.py @@ -1,5 +1,6 @@ from modelgauge.prompt import ChatPrompt, TextPrompt -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from pydantic import BaseModel @@ -21,10 +22,10 @@ def __init__(self, uid: str = "fake-sut"): super().__init__(uid) self.evaluate_calls = 0 - def translate_text_prompt(self, prompt: TextPrompt, options: 
SUTOptions) -> FakeSUTRequest: + def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> FakeSUTRequest: return FakeSUTRequest(text=prompt.text) - def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> FakeSUTRequest: + def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> FakeSUTRequest: return FakeSUTRequest(text=prompt.messages[-1].text) def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse: diff --git a/tests/modelgauge_tests/sut_tests/test_anthropic_api.py b/tests/modelgauge_tests/sut_tests/test_anthropic_api.py index 720cd5062..39bf3fcbe 100644 --- a/tests/modelgauge_tests/sut_tests/test_anthropic_api.py +++ b/tests/modelgauge_tests/sut_tests/test_anthropic_api.py @@ -4,7 +4,8 @@ from modelgauge.general import APIException from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.suts.anthropic_api import AnthropicRequest, AnthropicApiKey, AnthropicSUT from modelgauge.suts.openai_client import OpenAIChatMessage @@ -24,12 +25,12 @@ def simple_anthropic_request(): def test_anthropic_api_translate_request_default_sut_options(fake_sut): prompt = TextPrompt(text="some-text") - request = fake_sut.translate_text_prompt(prompt, SUTOptions()) + request = fake_sut.translate_text_prompt(prompt, ModelOptions()) assert isinstance(request, AnthropicRequest) assert request.model == "fake-model" assert request.messages == [OpenAIChatMessage(content="some-text", role="user")] - assert request.max_tokens == 100 # Default SUTOptions value + assert request.max_tokens == 100 # Default ModelOptions value # Make sure all other attributes are not set request_dict = request.model_dump(exclude_none=False) @@ -40,7 +41,7 @@ def test_anthropic_api_translate_request_default_sut_options(fake_sut): def test_anthropic_api_translate_request_non_default_sut_options(fake_sut): 
"""Test that all possible generation parameters are set correctly.""" - options = SUTOptions( + options = ModelOptions( max_tokens=200, # Overwrite default value temperature=0.5, top_k_per_token=10, diff --git a/tests/modelgauge_tests/sut_tests/test_aws_bedrock_client.py b/tests/modelgauge_tests/sut_tests/test_aws_bedrock_client.py index 9f6f63f51..917d7384d 100644 --- a/tests/modelgauge_tests/sut_tests/test_aws_bedrock_client.py +++ b/tests/modelgauge_tests/sut_tests/test_aws_bedrock_client.py @@ -2,7 +2,8 @@ from unittest.mock import patch from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.typed_data import is_typeable from modelgauge.suts.aws_bedrock_client import ( @@ -43,7 +44,7 @@ def _make_response(response_text): def test_translate_text_prompt(fake_sut): - default_options = SUTOptions() + default_options = ModelOptions() prompt = TextPrompt(text="some-text") request = fake_sut.translate_text_prompt(prompt, default_options) @@ -52,7 +53,7 @@ def test_translate_text_prompt(fake_sut): assert len(request.messages) == 1 message = request.messages[0] assert message.content == [{"text": "some-text"}] - assert request.inferenceConfig.maxTokens == default_options.max_tokens # Default SUTOptions value + assert request.inferenceConfig.maxTokens == default_options.max_tokens # Default ModelOptions value def test_can_cache_request(): diff --git a/tests/modelgauge_tests/sut_tests/test_baseten_api.py b/tests/modelgauge_tests/sut_tests/test_baseten_api.py index 4c79332d9..cd7ed5791 100644 --- a/tests/modelgauge_tests/sut_tests/test_baseten_api.py +++ b/tests/modelgauge_tests/sut_tests/test_baseten_api.py @@ -1,6 +1,7 @@ import pytest -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.suts.baseten_api import ( 
BasetenPromptSUT, BasetenMessagesSUT, @@ -57,7 +58,7 @@ def _make_response(response_text): def test_baseten_api_translate_prompt_options(baseten_prompt_sut): - options = SUTOptions(max_tokens=200) + options = ModelOptions(max_tokens=200) q = "What is xyzzy?" prompt = TextPrompt(text=q) @@ -68,7 +69,7 @@ def test_baseten_api_translate_prompt_options(baseten_prompt_sut): def test_baseten_api_translate_messages_options(baseten_messages_sut): - options = SUTOptions(max_tokens=200, temperature=0.5, top_p=0.5, top_k_per_token=10, frequency_penalty=2) + options = ModelOptions(max_tokens=200, temperature=0.5, top_p=0.5, top_k_per_token=10, frequency_penalty=2) q = "What is xyzzy?" prompt = TextPrompt(text=q) diff --git a/tests/modelgauge_tests/sut_tests/test_google_genai.py b/tests/modelgauge_tests/sut_tests/test_google_genai.py index 7aba689f1..759b7518e 100644 --- a/tests/modelgauge_tests/sut_tests/test_google_genai.py +++ b/tests/modelgauge_tests/sut_tests/test_google_genai.py @@ -5,7 +5,8 @@ from google.genai.types import GenerateContentConfig, GenerateContentResponse, ThinkingConfig, FinishReason from modelgauge.prompt import TextPrompt -from modelgauge.sut import REFUSAL_RESPONSE, SUTOptions, SUTResponse +from modelgauge.sut import REFUSAL_RESPONSE, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.suts.google_genai import GenAiRequest, GoogleGenAiSUT, GoogleAiApiKey _MODEL_NAME = "some-model" @@ -53,7 +54,7 @@ def mock_model(mock_model_patch, fake_raw_response): def test_google_genai_translate_request_default_options(google_default_sut): prompt = TextPrompt(text="some-text") - request = google_default_sut.translate_text_prompt(prompt, SUTOptions()) + request = google_default_sut.translate_text_prompt(prompt, ModelOptions()) assert request == GenAiRequest( model=_MODEL_NAME, contents="some-text", @@ -71,7 +72,7 @@ def test_google_genai_translate_request_default_options(google_default_sut): def 
test_google_genai_translate_request_default_options_no_reasoning(google_unreasoning_sut): prompt = TextPrompt(text="some-text") - request = google_unreasoning_sut.translate_text_prompt(prompt, SUTOptions()) + request = google_unreasoning_sut.translate_text_prompt(prompt, ModelOptions()) assert request == GenAiRequest( model=_MODEL_NAME, contents="some-text", @@ -90,7 +91,7 @@ def test_google_genai_translate_request_default_options_no_reasoning(google_unre def test_google_genai_translate_request_generation_options(google_default_sut): prompt = TextPrompt(text="some-text") - options = SUTOptions( + options = ModelOptions( stop_sequences=["stop"], max_tokens=200, temperature=0.5, top_k_per_token=5, frequency_penalty=0.5 ) request = google_default_sut.translate_text_prompt(prompt, options) diff --git a/tests/modelgauge_tests/sut_tests/test_huggingface_api.py b/tests/modelgauge_tests/sut_tests/test_huggingface_api.py index 50221c75a..e0e867f89 100644 --- a/tests/modelgauge_tests/sut_tests/test_huggingface_api.py +++ b/tests/modelgauge_tests/sut_tests/test_huggingface_api.py @@ -3,7 +3,8 @@ from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.suts.huggingface_api import ( HuggingFaceChatParams, HuggingFaceChatRequest, @@ -23,7 +24,7 @@ def _make_sut_request(text, **params): def test_huggingface_api_translate_text_prompt_request(fake_sut): prompt_text = "some text prompt" - sut_options = SUTOptions(max_tokens=5, temperature=1.0, random="should be ignored") + sut_options = ModelOptions(max_tokens=5, temperature=1.0, random="should be ignored") prompt = TextPrompt(text=prompt_text) request = fake_sut.translate_text_prompt(prompt, sut_options) diff --git a/tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py 
b/tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py index 8ddd3ab0d..f151c8948 100644 --- a/tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py +++ b/tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py @@ -19,7 +19,8 @@ import modelgauge.prompt from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken from modelgauge.prompt import TextPrompt, ChatPrompt, ChatRole -from modelgauge.sut import SUTOptions, SUTResponse, TokenProbability, TopTokens +from modelgauge.sut import SUTResponse, TokenProbability, TopTokens +from modelgauge.model_options import ModelOptions from modelgauge.suts.huggingface_chat_completion import ( HUGGING_FACE_NUM_RETRIES, ChatMessage, @@ -52,7 +53,7 @@ def _make_sut_options(top_logprobs=None): extra_options = {} if top_logprobs is not None: extra_options["top_logprobs"] = top_logprobs - return SUTOptions(max_tokens=5, temperature=1.0, random="random", **extra_options) + return ModelOptions(max_tokens=5, temperature=1.0, random="random", **extra_options) def _make_sut_request(top_logprobs: Optional[int] = None): diff --git a/tests/modelgauge_tests/sut_tests/test_indirect_sut.py b/tests/modelgauge_tests/sut_tests/test_indirect_sut.py index 0e3800d9b..ec486ef7e 100644 --- a/tests/modelgauge_tests/sut_tests/test_indirect_sut.py +++ b/tests/modelgauge_tests/sut_tests/test_indirect_sut.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions from modelgauge.suts.indirect_sut import ( IndirectSUT, IndirectSUTRequest, @@ -33,7 +33,7 @@ def sut(self): def test_translate_text_prompt(self, sut): prompt = TextPrompt(text="text") - options = SUTOptions(max_tokens=20, temperature=0.3) + options = ModelOptions(max_tokens=20, temperature=0.3) request = sut.translate_text_prompt(prompt, options) diff --git 
a/tests/modelgauge_tests/sut_tests/test_meta_llama.py b/tests/modelgauge_tests/sut_tests/test_meta_llama.py index cf7ddf766..c6e09fccb 100644 --- a/tests/modelgauge_tests/sut_tests/test_meta_llama.py +++ b/tests/modelgauge_tests/sut_tests/test_meta_llama.py @@ -3,7 +3,8 @@ from llama_api_client.types import CreateChatCompletionResponse from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.suts.meta_llama_client import InputMessage, MetaLlamaApiKey, MetaLlamaChatRequest, MetaLlamaSUT from pytest import fixture from requests import HTTPError # type:ignore @@ -45,7 +46,7 @@ def sut(): def test_translate_text_prompt(sut): - sut_options = SUTOptions() + sut_options = ModelOptions() result = sut.translate_text_prompt(TextPrompt(text="Why did the chicken cross the road?"), sut_options) assert result == MetaLlamaChatRequest( model="a_model", diff --git a/tests/modelgauge_tests/sut_tests/test_mistral_sut.py b/tests/modelgauge_tests/sut_tests/test_mistral_sut.py index 11d8f0a29..9da765905 100644 --- a/tests/modelgauge_tests/sut_tests/test_mistral_sut.py +++ b/tests/modelgauge_tests/sut_tests/test_mistral_sut.py @@ -6,7 +6,8 @@ ) from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.suts.mistral_client import MistralAIAPIKey from modelgauge.suts.mistral_sut import ( MistralAIResponse, @@ -57,7 +58,7 @@ class TestMistralAISut: def test_request(self, sut, req): translated_req = sut.translate_text_prompt( - TextPrompt(text="Why did the chicken cross the road?"), SUTOptions(temperature=0.3, max_tokens=91) + TextPrompt(text="Why did the chicken cross the road?"), ModelOptions(temperature=0.3, max_tokens=91) ) assert translated_req.model_dump(exclude_none=True) == req diff --git 
a/tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py b/tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py index b9921c5d0..4000ec715 100644 --- a/tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py +++ b/tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py @@ -7,7 +7,8 @@ from openai.types.chat import ChatCompletion from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions def _make_client(): @@ -17,7 +18,7 @@ def _make_client(): def test_openai_chat_translate_request(): client = _make_client() prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, SUTOptions()) + request = client.translate_text_prompt(prompt, ModelOptions()) assert request == OpenAIChatRequest( model="some-model", messages=[OpenAIChatMessage(content="some-text", role="user")], diff --git a/tests/modelgauge_tests/sut_tests/test_openai_client.py b/tests/modelgauge_tests/sut_tests/test_openai_client.py index b96e7fba4..d24f09fde 100644 --- a/tests/modelgauge_tests/sut_tests/test_openai_client.py +++ b/tests/modelgauge_tests/sut_tests/test_openai_client.py @@ -4,7 +4,8 @@ from openai.types.chat import ChatCompletion from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions, SUTResponse, TokenProbability, TopTokens +from modelgauge.sut import SUTResponse, TokenProbability, TopTokens +from modelgauge.model_options import ModelOptions from modelgauge.suts.openai_client import ( OpenAIApiKey, OpenAIChat, @@ -71,7 +72,7 @@ def test_openai_constructor(): def test_openai_chat_translate_request(): client = _make_client() prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, SUTOptions()) + request = client.translate_text_prompt(prompt, ModelOptions()) assert request == OpenAIChatRequest( model="some-model", messages=[OpenAIChatMessage(content="some-text", 
role="user")], @@ -82,7 +83,7 @@ def test_openai_chat_translate_request(): def test_openai_chat_translate_request_logprobs(): client = _make_client() prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, SUTOptions(top_logprobs=2)) + request = client.translate_text_prompt(prompt, ModelOptions(top_logprobs=2)) assert request == OpenAIChatRequest( model="some-model", messages=[OpenAIChatMessage(content="some-text", role="user")], @@ -96,7 +97,7 @@ def test_openai_chat_translate_request_excessive_logprobs(): client = _make_client() # Set value above limit of 20 prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, SUTOptions(top_logprobs=21)) + request = client.translate_text_prompt(prompt, ModelOptions(top_logprobs=21)) assert request == OpenAIChatRequest( model="some-model", messages=[OpenAIChatMessage(content="some-text", role="user")], diff --git a/tests/modelgauge_tests/sut_tests/test_together_client.py b/tests/modelgauge_tests/sut_tests/test_together_client.py index 09fab4f97..d7765668c 100644 --- a/tests/modelgauge_tests/sut_tests/test_together_client.py +++ b/tests/modelgauge_tests/sut_tests/test_together_client.py @@ -7,7 +7,8 @@ from modelgauge.general import APIException from modelgauge.prompt import ChatMessage, ChatPrompt, ChatRole, TextPrompt from modelgauge.prompt_formatting import format_chat -from modelgauge.sut import SUTOptions, SUTResponse, TokenProbability, TopTokens +from modelgauge.sut import SUTResponse, TokenProbability, TopTokens +from modelgauge.model_options import ModelOptions from modelgauge.suts.together_client import ( TogetherApiKey, TogetherChatResponse, @@ -84,7 +85,7 @@ def _make_client(sut_class): def test_together_translate_text_prompt_request(sut_class, request_class): client = _make_client(sut_class) prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, SUTOptions()) + request = client.translate_text_prompt(prompt, ModelOptions()) 
assert request == request_class( model="some-model", prompt="some-text", @@ -107,7 +108,7 @@ def test_together_translate_chat_prompt_request(sut_class, request_class): ChatMessage(text="more-text", role=ChatRole.sut), ] ) - request = client.translate_chat_prompt(prompt, SUTOptions()) + request = client.translate_chat_prompt(prompt, ModelOptions()) assert request == request_class( model="some-model", prompt=format_chat(prompt, user_role="user", sut_role="assistant"), @@ -119,7 +120,7 @@ def test_together_translate_chat_prompt_request(sut_class, request_class): def test_together_chat_translate_text_prompt_request(): client = _make_client(TogetherChatSUT) prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, SUTOptions()) + request = client.translate_text_prompt(prompt, ModelOptions()) assert request == TogetherChatRequest( model="some-model", messages=[TogetherChatRequest.Message(content="some-text", role="user")], @@ -136,7 +137,7 @@ def test_together_chat_translate_chat_prompt_request(): ChatMessage(text="more-text", role=ChatRole.sut), ] ) - request = client.translate_chat_prompt(prompt, SUTOptions()) + request = client.translate_chat_prompt(prompt, ModelOptions()) assert request == TogetherChatRequest( model="some-model", messages=[ @@ -157,7 +158,7 @@ def test_together_chat_translate_chat_prompt_request(): def test_together_translate_request_logprobs(sut_class, request_class): client = _make_client(sut_class) prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, SUTOptions(top_logprobs=1)) + request = client.translate_text_prompt(prompt, ModelOptions(top_logprobs=1)) assert request == request_class( model="some-model", prompt="some-text", @@ -170,7 +171,7 @@ def test_together_translate_request_logprobs(sut_class, request_class): def test_together_chat_translate_request_logprobs(): client = _make_client(TogetherChatSUT) prompt = TextPrompt(text="some-text") - request = 
client.translate_text_prompt(prompt, SUTOptions(top_logprobs=1)) + request = client.translate_text_prompt(prompt, ModelOptions(top_logprobs=1)) assert request == TogetherChatRequest( model="some-model", messages=[TogetherChatRequest.Message(content="some-text", role="user")], @@ -563,17 +564,17 @@ def sut(self): def test_translate_text_prompt_sets_max_tokens(self, sut): prompt = TextPrompt(text="some-text") - options = SUTOptions(max_tokens=50) + options = ModelOptions(max_tokens=50) request = sut.translate_text_prompt(prompt, options) assert request.max_tokens == 50 assert request.max_tokens_excl_thinking == 50 - options = SUTOptions(max_tokens=50, max_total_output_tokens=200) + options = ModelOptions(max_tokens=50, max_total_output_tokens=200) request = sut.translate_text_prompt(prompt, options) assert request.max_tokens == 200 assert request.max_tokens_excl_thinking == 50 - options = SUTOptions(max_total_output_tokens=200) + options = ModelOptions(max_total_output_tokens=200) request = sut.translate_text_prompt(prompt, options) assert request.max_tokens == 200 assert request.max_tokens_excl_thinking == 100 # Default max tokens @@ -581,17 +582,17 @@ def test_translate_text_prompt_sets_max_tokens(self, sut): def test_translate_chat_prompt_sets_max_tokens(self, sut): prompt = ChatPrompt(messages=[]) - options = SUTOptions(max_tokens=50) + options = ModelOptions(max_tokens=50) request = sut.translate_chat_prompt(prompt, options) assert request.max_tokens == 50 assert request.max_tokens_excl_thinking == 50 - options = SUTOptions(max_tokens=50, max_total_output_tokens=200) + options = ModelOptions(max_tokens=50, max_total_output_tokens=200) request = sut.translate_chat_prompt(prompt, options) assert request.max_tokens == 200 assert request.max_tokens_excl_thinking == 50 - options = SUTOptions(max_total_output_tokens=200) + options = ModelOptions(max_total_output_tokens=200) request = sut.translate_chat_prompt(prompt, options) assert request.max_tokens == 200 assert 
request.max_tokens_excl_thinking == 100 # Default max tokens diff --git a/tests/modelgauge_tests/sut_tests/test_vertexai_mistral_sut.py b/tests/modelgauge_tests/sut_tests/test_vertexai_mistral_sut.py index 6cfdaf028..1e8814513 100644 --- a/tests/modelgauge_tests/sut_tests/test_vertexai_mistral_sut.py +++ b/tests/modelgauge_tests/sut_tests/test_vertexai_mistral_sut.py @@ -1,6 +1,7 @@ import pytest from modelgauge.prompt import TextPrompt -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.suts.vertexai_client import VertexAIProjectId, VertexAIRegion from modelgauge.suts.vertexai_mistral_sut import ( VertexAIMistralAISut, @@ -60,7 +61,7 @@ class TestMistralAISut: def test_request(self, sut, req): translated_req = sut.translate_text_prompt( - TextPrompt(text="Why did the chicken cross the road?"), options=SUTOptions(temperature=0.5, max_tokens=17) + TextPrompt(text="Why did the chicken cross the road?"), options=ModelOptions(temperature=0.5, max_tokens=17) ) assert translated_req.model_dump(exclude_none=True) == req diff --git a/tests/modelgauge_tests/test_cli.py b/tests/modelgauge_tests/test_cli.py index 67de8c497..61e9af6ae 100644 --- a/tests/modelgauge_tests/test_cli.py +++ b/tests/modelgauge_tests/test_cli.py @@ -18,7 +18,8 @@ from modelgauge.ensemble_annotator import EnsembleAnnotator from modelgauge.preflight import check_secrets, listify from modelgauge.secret_values import InjectSecret -from modelgauge.sut import SUT, SUTOptions +from modelgauge.sut import SUT +from modelgauge.model_options import ModelOptions from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS from modelgauge.test_registry import TESTS @@ -113,7 +114,7 @@ def test_run_sut_with_options(mock_translate_text_prompt): ) options_arg = mock_translate_text_prompt.call_args_list[0][0][1] - assert options_arg == SUTOptions(max_tokens=42, temperature=0.5, 
top_p=0.0, top_k_per_token=0) + assert options_arg == ModelOptions(max_tokens=42, temperature=0.5, top_p=0.0, top_k_per_token=0) def test_run_annotator_demo(): diff --git a/tests/modelgauge_tests/test_modelship_sut.py b/tests/modelgauge_tests/test_modelship_sut.py index 8afbf814a..d0ae1cc02 100644 --- a/tests/modelgauge_tests/test_modelship_sut.py +++ b/tests/modelgauge_tests/test_modelship_sut.py @@ -1,7 +1,7 @@ from unittest.mock import patch from modelgauge.prompt import ChatPrompt, ChatMessage, ChatRole -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.modelship_sut import ModelShipSUTFactory @@ -15,7 +15,7 @@ def test_basic_request_with_vllm_options(): factory = ModelShipSUTFactory(raw_secrets={"modelship": {"api_key": "whatever"}}) sut = factory.make_sut(definition) prompt = ChatPrompt(messages=[ChatMessage(text="Why not?", role=ChatRole.user)]) - request = sut.translate_chat_prompt(prompt, SUTOptions()) + request = sut.translate_chat_prompt(prompt, ModelOptions()) with patch("openai.resources.chat.completions.Completions.create") as fake_create: sut.evaluate(request) diff --git a/tests/modelgauge_tests/test_object_creation.py b/tests/modelgauge_tests/test_object_creation.py index 6fe99e49e..45eab8913 100644 --- a/tests/modelgauge_tests/test_object_creation.py +++ b/tests/modelgauge_tests/test_object_creation.py @@ -11,7 +11,8 @@ from modelgauge.locales import EN_US # see "workaround" below from modelgauge.prompt import TextPrompt from modelgauge.record_init import InitializationRecord -from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_registry import SUTS from modelgauge.suts.baseten_api import BasetenSUT @@ -123,7 +124,7 @@ def 
test_all_suts_can_evaluate(sut_name): if AcceptsTextPrompt in sut.capabilities: native_request = sut.translate_text_prompt( TextPrompt(text="What is your name?"), - SUTOptions(max_tokens=3), + ModelOptions(max_tokens=3), ) else: raise AssertionError("Update test to handle other kinds of prompts.") @@ -144,7 +145,7 @@ def test_can_cache_all_sut_responses(sut_name, tmpdir): TextPrompt( text="What is your name?", ), - options=SUTOptions(max_tokens=3), + options=ModelOptions(max_tokens=3), ) else: raise AssertionError("Update test to handle other kinds of prompts.") diff --git a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py index 82f56312f..91baabc0b 100644 --- a/tests/modelgauge_tests/test_pipeline_runner.py +++ b/tests/modelgauge_tests/test_pipeline_runner.py @@ -12,7 +12,7 @@ from modelgauge.ensemble_annotator import EnsembleAnnotator from modelgauge.pipeline_runner import AnnotatorRunner, PromptPlusAnnotatorRunner, PromptRunner, build_runner from modelgauge.prompt_pipeline import PromptSink, PromptSource, PromptSutAssigner, PromptSutWorkers -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions NUM_PROMPTS = 3 # Number of prompts in the prompts file @@ -177,7 +177,7 @@ def test_output_dir(self, tmp_path, runner_basic): assert runner_basic.output_dir() == tmp_path / runner_basic.run_id def test_pipeline_segments(self, tmp_path, prompts_dataset, prompts_file, suts): - sut_options = SUTOptions(max_tokens=42) + sut_options = ModelOptions(max_tokens=42) runner = PromptRunner( suts=suts, num_workers=20, input_dataset=prompts_dataset, output_dir=tmp_path, sut_options=sut_options ) @@ -295,7 +295,7 @@ def test_output_dir(self, tmp_path, runner_basic): assert runner_basic.output_dir() == tmp_path / runner_basic.run_id def test_pipeline_segments(self, tmp_path, prompts_dataset, prompts_file, suts, annotators): - sut_options = SUTOptions(max_tokens=42) + sut_options = 
ModelOptions(max_tokens=42) runner = PromptPlusAnnotatorRunner( suts=suts, annotators=annotators, diff --git a/tests/modelgauge_tests/test_prompt_pipeline.py b/tests/modelgauge_tests/test_prompt_pipeline.py index 0fd107c7e..e8f427bbb 100644 --- a/tests/modelgauge_tests/test_prompt_pipeline.py +++ b/tests/modelgauge_tests/test_prompt_pipeline.py @@ -22,7 +22,8 @@ PromptSutWorkers, ) from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from modelgauge_tests.fake_sut import FakeSUT, FakeSUTRequest, FakeSUTResponse @@ -138,7 +139,7 @@ def test_prompt_sut_worker_sends_prompt_options(suts): mock.return_value = FakeSUTRequest(text="") suts["fake1"].translate_text_prompt = mock prompt = TextPrompt(text="a prompt") - sut_options = SUTOptions(max_tokens=42, top_p=0.5, temperature=0.5) + sut_options = ModelOptions(max_tokens=42, top_p=0.5, temperature=0.5) prompt_with_context = TestItem(source_id="1", prompt=prompt) w = PromptSutWorkers(suts, sut_options=sut_options) diff --git a/tests/modelgauge_tests/test_records.py b/tests/modelgauge_tests/test_records.py index 61dce3d36..f01834344 100644 --- a/tests/modelgauge_tests/test_records.py +++ b/tests/modelgauge_tests/test_records.py @@ -8,7 +8,8 @@ SUTResponseAnnotations, TestItem, ) -from modelgauge.sut import SUTOptions, SUTResponse +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions from pydantic import BaseModel @@ -39,7 +40,7 @@ def test_serialize_test_record(): ), test_uid="some-test", test_initialization=InitializationRecord(module="some-module", class_name="test-class", args=[], kwargs={}), - sut_options=SUTOptions(max_tokens=17), + sut_options=ModelOptions(max_tokens=17), dependency_versions={"d1": "v1"}, sut_uid="some-sut", sut_initialization=InitializationRecord( @@ -164,7 +165,7 @@ def test_serialize_test_record(): 
def test_round_trip_test_item(): prompt = TestItem( - prompt=TextPrompt(text="some-text", options=SUTOptions(max_tokens=17)), + prompt=TextPrompt(text="some-text", options=ModelOptions(max_tokens=17)), source_id="id01", context=MockContext(context_field="prompt-context"), ) diff --git a/tests/modelgauge_tests/test_test_decorator.py b/tests/modelgauge_tests/test_test_decorator.py index 3e82d08d0..a0902eab4 100644 --- a/tests/modelgauge_tests/test_test_decorator.py +++ b/tests/modelgauge_tests/test_test_decorator.py @@ -3,7 +3,7 @@ from modelgauge.prompt import ChatPrompt, TextPrompt from modelgauge.record_init import InitializationRecord from modelgauge.single_turn_prompt_response import TestItem -from modelgauge.sut import SUTOptions +from modelgauge.model_options import ModelOptions from modelgauge.sut_capabilities import ( AcceptsChatPrompt, AcceptsTextPrompt, @@ -154,7 +154,7 @@ def test_logprobs_required_not_requested(): @modelgauge_test(requires_sut_capabilities=[AcceptsTextPrompt]) class LogprobsNotRequiredAndRequested(SomePromptResponseTest): - _sut_options = SUTOptions(top_logprobs=1) + _sut_options = ModelOptions(top_logprobs=1) def make_test_items(self, dependency_helper): return [ @@ -174,7 +174,7 @@ def test_logprobs_not_required_and_requested(): @modelgauge_test(requires_sut_capabilities=[ProducesPerTokenLogProbabilities, AcceptsTextPrompt]) class LogprobsRequiredAndRequested(SomePromptResponseTest): - _sut_options = SUTOptions(top_logprobs=1) + _sut_options = ModelOptions(top_logprobs=1) def make_test_items(self, dependency_helper): return [ From ed7ccf7c908e156799cdf9e31e95079238e3dbd7 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 12 Jan 2026 16:00:39 -0800 Subject: [PATCH 2/5] no default max_tokens --- src/modelgauge/model_options.py | 8 ++++++-- src/modelgauge/suts/together_client.py | 7 +++++-- .../sut_tests/test_anthropic_api.py | 1 - .../sut_tests/test_google_genai.py | 4 ++-- .../sut_tests/test_nvidia_nim_api_client.py | 2 +- 
.../sut_tests/test_openai_client.py | 6 +++--- .../sut_tests/test_together_client.py | 16 ++++++++-------- 7 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/modelgauge/model_options.py b/src/modelgauge/model_options.py index 7d10ead18..4c18299ba 100644 --- a/src/modelgauge/model_options.py +++ b/src/modelgauge/model_options.py @@ -10,7 +10,7 @@ class ModelOptions(BaseModel): Not all SUTs and annotators respect all options. """ - max_tokens: int = 100 + max_tokens: Optional[int] = None """Maximum number of tokens to generate (per completion)""" max_total_output_tokens: Optional[int] = None @@ -45,7 +45,11 @@ class ModelOptions(BaseModel): @model_validator(mode="after") def check_max_total_output_tokens(self): - if self.max_total_output_tokens is not None and self.max_total_output_tokens < self.max_tokens: + if ( + self.max_total_output_tokens is not None + and self.max_tokens is not None + and self.max_total_output_tokens < self.max_tokens + ): raise ValueError( f"Invalid ModelOptions. max_total_output_tokens ({self.max_total_output_tokens}) must be >= max_tokens ({self.max_tokens})." 
) diff --git a/src/modelgauge/suts/together_client.py b/src/modelgauge/suts/together_client.py index 11838c178..e0d46e66a 100644 --- a/src/modelgauge/suts/together_client.py +++ b/src/modelgauge/suts/together_client.py @@ -72,7 +72,7 @@ class TogetherCompletionsRequest(BaseModel): # https://docs.together.ai/reference/completions model: str prompt: str - max_tokens: int + max_tokens: int = 100 stop: Optional[List[str]] = None temperature: Optional[float] = None top_p: Optional[float] = None @@ -128,16 +128,19 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> To return self._translate_request(format_chat(prompt, user_role=_USER_ROLE, sut_role=_ASSISTANT_ROLE), options) def _translate_request(self, text, options): + exclude_none_kwargs = {} + if options.max_tokens is not None: + exclude_none_kwargs["max_tokens"] = options.max_tokens return TogetherCompletionsRequest( model=self.model, prompt=text, - max_tokens=options.max_tokens, stop=options.stop_sequences, temperature=options.temperature, top_p=options.top_p, top_k=options.top_k_per_token, repetition_penalty=options.frequency_penalty, logprobs=options.top_logprobs, + **exclude_none_kwargs, ) def evaluate(self, request: TogetherCompletionsRequest) -> TogetherCompletionsResponse: diff --git a/tests/modelgauge_tests/sut_tests/test_anthropic_api.py b/tests/modelgauge_tests/sut_tests/test_anthropic_api.py index 39bf3fcbe..b363e6577 100644 --- a/tests/modelgauge_tests/sut_tests/test_anthropic_api.py +++ b/tests/modelgauge_tests/sut_tests/test_anthropic_api.py @@ -30,7 +30,6 @@ def test_anthropic_api_translate_request_default_sut_options(fake_sut): assert isinstance(request, AnthropicRequest) assert request.model == "fake-model" assert request.messages == [OpenAIChatMessage(content="some-text", role="user")] - assert request.max_tokens == 100 # Default ModelOptions value # Make sure all other attributes are not set request_dict = request.model_dump(exclude_none=False) diff --git 
a/tests/modelgauge_tests/sut_tests/test_google_genai.py b/tests/modelgauge_tests/sut_tests/test_google_genai.py index 759b7518e..189d636d1 100644 --- a/tests/modelgauge_tests/sut_tests/test_google_genai.py +++ b/tests/modelgauge_tests/sut_tests/test_google_genai.py @@ -60,7 +60,7 @@ def test_google_genai_translate_request_default_options(google_default_sut): contents="some-text", config=GenerateContentConfig( stop_sequences=None, - max_output_tokens=100, + max_output_tokens=None, temperature=None, top_p=None, top_k=None, @@ -78,7 +78,7 @@ def test_google_genai_translate_request_default_options_no_reasoning(google_unre contents="some-text", config=GenerateContentConfig( stop_sequences=None, - max_output_tokens=100, + max_output_tokens=None, temperature=None, top_p=None, top_k=None, diff --git a/tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py b/tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py index 4000ec715..a0ed9e98c 100644 --- a/tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py +++ b/tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py @@ -18,7 +18,7 @@ def _make_client(): def test_openai_chat_translate_request(): client = _make_client() prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, ModelOptions()) + request = client.translate_text_prompt(prompt, ModelOptions(max_tokens=100)) assert request == OpenAIChatRequest( model="some-model", messages=[OpenAIChatMessage(content="some-text", role="user")], diff --git a/tests/modelgauge_tests/sut_tests/test_openai_client.py b/tests/modelgauge_tests/sut_tests/test_openai_client.py index d24f09fde..e769e92d4 100644 --- a/tests/modelgauge_tests/sut_tests/test_openai_client.py +++ b/tests/modelgauge_tests/sut_tests/test_openai_client.py @@ -72,7 +72,7 @@ def test_openai_constructor(): def test_openai_chat_translate_request(): client = _make_client() prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, 
ModelOptions()) + request = client.translate_text_prompt(prompt, ModelOptions(max_tokens=100)) assert request == OpenAIChatRequest( model="some-model", messages=[OpenAIChatMessage(content="some-text", role="user")], @@ -87,7 +87,7 @@ def test_openai_chat_translate_request_logprobs(): assert request == OpenAIChatRequest( model="some-model", messages=[OpenAIChatMessage(content="some-text", role="user")], - max_completion_tokens=100, + max_completion_tokens=None, logprobs=True, top_logprobs=2, ) @@ -101,7 +101,7 @@ def test_openai_chat_translate_request_excessive_logprobs(): assert request == OpenAIChatRequest( model="some-model", messages=[OpenAIChatMessage(content="some-text", role="user")], - max_completion_tokens=100, + max_completion_tokens=None, logprobs=True, top_logprobs=20, ) diff --git a/tests/modelgauge_tests/sut_tests/test_together_client.py b/tests/modelgauge_tests/sut_tests/test_together_client.py index d7765668c..4b8af8f61 100644 --- a/tests/modelgauge_tests/sut_tests/test_together_client.py +++ b/tests/modelgauge_tests/sut_tests/test_together_client.py @@ -85,7 +85,7 @@ def _make_client(sut_class): def test_together_translate_text_prompt_request(sut_class, request_class): client = _make_client(sut_class) prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, ModelOptions()) + request = client.translate_text_prompt(prompt, ModelOptions(max_tokens=100)) assert request == request_class( model="some-model", prompt="some-text", @@ -108,7 +108,7 @@ def test_together_translate_chat_prompt_request(sut_class, request_class): ChatMessage(text="more-text", role=ChatRole.sut), ] ) - request = client.translate_chat_prompt(prompt, ModelOptions()) + request = client.translate_chat_prompt(prompt, ModelOptions(max_tokens=100)) assert request == request_class( model="some-model", prompt=format_chat(prompt, user_role="user", sut_role="assistant"), @@ -120,7 +120,7 @@ def test_together_translate_chat_prompt_request(sut_class, 
request_class): def test_together_chat_translate_text_prompt_request(): client = _make_client(TogetherChatSUT) prompt = TextPrompt(text="some-text") - request = client.translate_text_prompt(prompt, ModelOptions()) + request = client.translate_text_prompt(prompt, ModelOptions(max_tokens=100)) assert request == TogetherChatRequest( model="some-model", messages=[TogetherChatRequest.Message(content="some-text", role="user")], @@ -144,7 +144,7 @@ def test_together_chat_translate_chat_prompt_request(): TogetherChatRequest.Message(content="some-text", role="user"), TogetherChatRequest.Message(content="more-text", role="assistant"), ], - max_tokens=100, + max_tokens=None, n=1, ) @@ -162,7 +162,7 @@ def test_together_translate_request_logprobs(sut_class, request_class): assert request == request_class( model="some-model", prompt="some-text", - max_tokens=100, + max_tokens=100, # Default for the completions SUT. n=1, logprobs=1, ) @@ -175,7 +175,7 @@ def test_together_chat_translate_request_logprobs(): assert request == TogetherChatRequest( model="some-model", messages=[TogetherChatRequest.Message(content="some-text", role="user")], - max_tokens=100, + max_tokens=None, n=1, logprobs=1, ) @@ -577,7 +577,7 @@ def test_translate_text_prompt_sets_max_tokens(self, sut): options = ModelOptions(max_total_output_tokens=200) request = sut.translate_text_prompt(prompt, options) assert request.max_tokens == 200 - assert request.max_tokens_excl_thinking == 100 # Default max tokens + assert request.max_tokens_excl_thinking == None # Default max tokens def test_translate_chat_prompt_sets_max_tokens(self, sut): prompt = ChatPrompt(messages=[]) @@ -595,7 +595,7 @@ def test_translate_chat_prompt_sets_max_tokens(self, sut): options = ModelOptions(max_total_output_tokens=200) request = sut.translate_chat_prompt(prompt, options) assert request.max_tokens == 200 - assert request.max_tokens_excl_thinking == 100 # Default max tokens + assert request.max_tokens_excl_thinking == None 
@pytest.mark.parametrize( "full_text, response_text", [("hmm\\n Output", "Output"), ("hmmm", "")] From 91da4397a33ed599be036eaaeafec71d802ec32f Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 12 Jan 2026 16:24:11 -0800 Subject: [PATCH 3/5] CLI sets default max_tokens to 100 instead --- src/modelgauge/cli.py | 3 +-- src/modelgauge/command_line.py | 2 +- tests/modelgauge_tests/test_pipeline_runner.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/modelgauge/cli.py b/src/modelgauge/cli.py index a473c34b3..1ac3eda8e 100644 --- a/src/modelgauge/cli.py +++ b/src/modelgauge/cli.py @@ -142,14 +142,13 @@ def list_secrets() -> None: def run_sut( sut: str, prompt: str, - max_tokens: Optional[int], + max_tokens: int, temp: Optional[float], top_logprobs: Optional[int], top_p: Optional[float], top_k: Optional[int], ): """Send a prompt from the command line to a SUT.""" - # TODO Consider a SUT factory that takes in a SUTDefinition and returns a SUT options = ModelOptions.create_from_arguments(max_tokens, temp, top_p, top_k, top_logprobs) diff --git a/src/modelgauge/command_line.py b/src/modelgauge/command_line.py index 9efacfed6..d95f8df40 100644 --- a/src/modelgauge/command_line.py +++ b/src/modelgauge/command_line.py @@ -72,7 +72,7 @@ def load_local_plugins(_, __, path: pathlib.Path): ) MAX_TOKENS_OPTION = click.option( - "--max-tokens", default=None, type=click.IntRange(1), help="How many tokens to generate for each completion."
) TEMP_OPTION = click.option("--temp", default=None, type=float, help="SUT temperature value.") TOP_P_OPTION = click.option("--top-p", default=None, type=float, help="SUT top-p value.") diff --git a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py index 91baabc0b..1b75e4ad3 100644 --- a/tests/modelgauge_tests/test_pipeline_runner.py +++ b/tests/modelgauge_tests/test_pipeline_runner.py @@ -113,7 +113,7 @@ def assert_basic_sut_metadata(metadata): "kwargs": {}, "module": "modelgauge_tests.fake_sut", }, - "sut_options": {"max_tokens": 100}, + "sut_options": {}, }, { "uid": "sut2", @@ -123,7 +123,7 @@ def assert_basic_sut_metadata(metadata): "kwargs": {}, "module": "modelgauge_tests.fake_sut", }, - "sut_options": {"max_tokens": 100}, + "sut_options": {}, }, ] assert metadata["responses"] == { From 83b58376a3c6589b84785a0a351c0f3039c68fc8 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 12 Jan 2026 16:33:25 -0800 Subject: [PATCH 4/5] Log sut_options in run-job and don't pass for annotator-only jobs --- src/modelgauge/pipeline_runner.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/modelgauge/pipeline_runner.py b/src/modelgauge/pipeline_runner.py index 1fbd01267..e03c39ab4 100644 --- a/src/modelgauge/pipeline_runner.py +++ b/src/modelgauge/pipeline_runner.py @@ -26,14 +26,12 @@ def __init__( input_dataset, output_dir, cache_dir=None, - sut_options=ModelOptions(), tag=None, ): self.num_workers = num_workers self.input_dataset = input_dataset self.root_dir = output_dir self.cache_dir = cache_dir - self.sut_options = sut_options self.tag = tag self.pipeline_segments = [] self.start_time = datetime.datetime.now() @@ -120,7 +118,9 @@ def _write_metadata(self): class PromptRunner(PipelineRunner): - def __init__(self, suts, **kwargs): + def __init__(self, suts, sut_options=ModelOptions(), **kwargs): + self.sut_options = sut_options + logger.info(f"Using SUT options: {self.sut_options}") 
self.suts = suts self.sut_worker = None # Convenience pointer. super().__init__(**kwargs) @@ -278,6 +278,7 @@ def build_runner( sut_uid_col=None, sut_response_col=None, jailbreak=False, + sut_options=None, **kwargs, ): if jailbreak and not (annotators and suts): @@ -304,9 +305,11 @@ def build_runner( ) # Build runner if suts and annotators: - pipeline_runner = PromptPlusAnnotatorRunner(suts=suts, annotators=annotators, input_dataset=dataset, **kwargs) + pipeline_runner = PromptPlusAnnotatorRunner( + suts=suts, annotators=annotators, input_dataset=dataset, sut_options=sut_options, **kwargs + ) elif suts: - pipeline_runner = PromptRunner(suts=suts, input_dataset=dataset, **kwargs) + pipeline_runner = PromptRunner(suts=suts, input_dataset=dataset, sut_options=sut_options, **kwargs) elif annotators: pipeline_runner = AnnotatorRunner(annotators=annotators, input_dataset=dataset, **kwargs) else: From 14e585ea643d85cce3762e7e9a1daf1dd071301b Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 12 Jan 2026 16:44:09 -0800 Subject: [PATCH 5/5] Annotation objects have top_logprobs field --- src/modelgauge/annotation.py | 5 ++++- src/modelgauge/model_options.py | 15 ++++++++++++++- src/modelgauge/sut.py | 15 +-------------- .../suts/huggingface_chat_completion.py | 4 ++-- src/modelgauge/suts/openai_client.py | 4 +--- src/modelgauge/suts/together_client.py | 4 ++-- tests/modelbench_tests/test_run_journal.py | 3 ++- .../sut_tests/test_huggingface_chat_completion.py | 4 ++-- .../sut_tests/test_openai_client.py | 4 ++-- .../sut_tests/test_together_client.py | 4 ++-- tests/modelgauge_tests/test_records.py | 1 + 11 files changed, 33 insertions(+), 30 deletions(-) diff --git a/src/modelgauge/annotation.py b/src/modelgauge/annotation.py index ac8a4e2bd..098557943 100644 --- a/src/modelgauge/annotation.py +++ b/src/modelgauge/annotation.py @@ -1,6 +1,8 @@ +from typing import Dict, Optional, Sequence + from pydantic import BaseModel -from typing import Dict, Optional +from 
modelgauge.model_options import TopTokens class MissingAnnotationError(Exception): @@ -16,6 +18,7 @@ class SafetyAnnotation(BaseModel): is_safe: bool is_valid: bool = True + top_logprobs: Optional[Sequence[TopTokens]] = None metadata: Optional[dict] = None diff --git a/src/modelgauge/model_options.py b/src/modelgauge/model_options.py index 4c18299ba..c9816525b 100644 --- a/src/modelgauge/model_options.py +++ b/src/modelgauge/model_options.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Sequence from pydantic import BaseModel, model_validator @@ -70,3 +70,16 @@ def create_from_arguments(max_tokens=None, temp=None, top_p=None, top_k=None, to options.top_logprobs = top_logprobs return options + + +class TokenProbability(BaseModel): + """Probability assigned to a given token.""" + + token: str + logprob: float + + +class TopTokens(BaseModel): + """List of most likely tokens and their probabilities.""" + + top_tokens: Sequence[TokenProbability] diff --git a/src/modelgauge/sut.py b/src/modelgauge/sut.py index 5531c7ad8..06865a425 100644 --- a/src/modelgauge/sut.py +++ b/src/modelgauge/sut.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from modelgauge.model_options import ModelOptions +from modelgauge.model_options import ModelOptions, TopTokens from modelgauge.not_implemented import not_implemented from modelgauge.prompt import ChatPrompt, TextPrompt from modelgauge.ready import Readyable, ReadyResponse @@ -14,19 +14,6 @@ REFUSAL_RESPONSE = "" -class TokenProbability(BaseModel): - """Probability assigned to a given token.""" - - token: str - logprob: float - - -class TopTokens(BaseModel): - """List of most likely tokens and their probabilities.""" - - top_tokens: Sequence[TokenProbability] - - class SUTResponse(BaseModel): """The data that came out of the SUT.""" diff --git a/src/modelgauge/suts/huggingface_chat_completion.py b/src/modelgauge/suts/huggingface_chat_completion.py index 1e31e0508..2ab001685 100644 --- 
a/src/modelgauge/suts/huggingface_chat_completion.py +++ b/src/modelgauge/suts/huggingface_chat_completion.py @@ -12,8 +12,8 @@ from modelgauge.prompt import TextPrompt, ChatPrompt from modelgauge.retry_decorator import retry from modelgauge.secret_values import InjectSecret -from modelgauge.sut import PromptResponseSUT, SUTResponse, TokenProbability, TopTokens -from modelgauge.model_options import ModelOptions +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS diff --git a/src/modelgauge/suts/openai_client.py b/src/modelgauge/suts/openai_client.py index f373a48a3..182215179 100644 --- a/src/modelgauge/suts/openai_client.py +++ b/src/modelgauge/suts/openai_client.py @@ -18,10 +18,8 @@ from modelgauge.sut import ( PromptResponseSUT, SUTResponse, - TokenProbability, - TopTokens, ) -from modelgauge.model_options import ModelOptions +from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.sut_capabilities import ( AcceptsChatPrompt, AcceptsTextPrompt, diff --git a/src/modelgauge/suts/together_client.py b/src/modelgauge/suts/together_client.py index e0d46e66a..aa1275829 100644 --- a/src/modelgauge/suts/together_client.py +++ b/src/modelgauge/suts/together_client.py @@ -12,8 +12,8 @@ from modelgauge.prompt_formatting import format_chat from modelgauge.tokenizer import GeneralTokenizer from modelgauge.secret_values import InjectSecret -from modelgauge.sut import PromptResponseSUT, SUTResponse, TokenProbability, TopTokens -from modelgauge.model_options import ModelOptions +from modelgauge.sut import PromptResponseSUT, SUTResponse +from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.sut_capabilities import 
AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS diff --git a/tests/modelbench_tests/test_run_journal.py b/tests/modelbench_tests/test_run_journal.py index a2015b859..925442493 100644 --- a/tests/modelbench_tests/test_run_journal.py +++ b/tests/modelbench_tests/test_run_journal.py @@ -14,7 +14,8 @@ from modelbench.benchmark_runner_items import Timer from modelbench.run_journal import RunJournal, for_journal from modelgauge.locales import EN_US -from modelgauge.sut import SUTResponse, TopTokens, TokenProbability +from modelgauge.sut import SUTResponse +from modelgauge.model_options import TokenProbability, TopTokens def assert_no_output(capsys): diff --git a/tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py b/tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py index f151c8948..de5602444 100644 --- a/tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py +++ b/tests/modelgauge_tests/sut_tests/test_huggingface_chat_completion.py @@ -19,8 +19,8 @@ import modelgauge.prompt from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken from modelgauge.prompt import TextPrompt, ChatPrompt, ChatRole -from modelgauge.sut import SUTResponse, TokenProbability, TopTokens -from modelgauge.model_options import ModelOptions +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.suts.huggingface_chat_completion import ( HUGGING_FACE_NUM_RETRIES, ChatMessage, diff --git a/tests/modelgauge_tests/sut_tests/test_openai_client.py b/tests/modelgauge_tests/sut_tests/test_openai_client.py index e769e92d4..cf4ede89c 100644 --- a/tests/modelgauge_tests/sut_tests/test_openai_client.py +++ b/tests/modelgauge_tests/sut_tests/test_openai_client.py @@ -4,8 +4,8 @@ from openai.types.chat import ChatCompletion from modelgauge.prompt 
import TextPrompt -from modelgauge.sut import SUTResponse, TokenProbability, TopTokens -from modelgauge.model_options import ModelOptions +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.suts.openai_client import ( OpenAIApiKey, OpenAIChat, diff --git a/tests/modelgauge_tests/sut_tests/test_together_client.py b/tests/modelgauge_tests/sut_tests/test_together_client.py index 4b8af8f61..f071234cd 100644 --- a/tests/modelgauge_tests/sut_tests/test_together_client.py +++ b/tests/modelgauge_tests/sut_tests/test_together_client.py @@ -7,8 +7,8 @@ from modelgauge.general import APIException from modelgauge.prompt import ChatMessage, ChatPrompt, ChatRole, TextPrompt from modelgauge.prompt_formatting import format_chat -from modelgauge.sut import SUTResponse, TokenProbability, TopTokens -from modelgauge.model_options import ModelOptions +from modelgauge.sut import SUTResponse +from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.suts.together_client import ( TogetherApiKey, TogetherChatResponse, diff --git a/tests/modelgauge_tests/test_records.py b/tests/modelgauge_tests/test_records.py index f01834344..057696753 100644 --- a/tests/modelgauge_tests/test_records.py +++ b/tests/modelgauge_tests/test_records.py @@ -142,6 +142,7 @@ def test_serialize_test_record(): "k1": { "is_safe": true, "is_valid": false, + "top_logprobs": null, "metadata": null } }