diff --git a/src/modelgauge/reasoning_handlers.py b/src/modelgauge/reasoning_handlers.py
new file mode 100644
index 000000000..395616633
--- /dev/null
+++ b/src/modelgauge/reasoning_handlers.py
@@ -0,0 +1,95 @@
+import copy
+from typing import Any
+from modellogger.log_config import get_logger
+from pydantic import BaseModel
+
+from modelgauge.model_options import ModelOptions
+from modelgauge.prompt import TextPrompt
+from modelgauge.sut import SUTResponse, PromptResponseSUT
+from modelgauge.tokenizer import GeneralTokenizer
+
+logger = get_logger(__name__)
+
+
+class ReasoningRequest(BaseModel):
+    """Wraps the underlying SUT's request together with the token budgets used to post-process its response."""
+
+    request: Any  # Request that is actually sent to the model.
+    max_content_tokens: int | None = None  # Number of tokens allowed for content (excluding thinking text).
+    max_total_tokens: int | None = None  # Total number of tokens allowed (thinking + content).
+
+
+class ThinkingMixin(PromptResponseSUT):
+    """
+    A mixin for SUTs that parses out thinking text from the output.
+
+    The output is expected to be in the form: {reasoning text}</think>{content text}.
+    If max_total_output_tokens is set in ModelOptions, that value will be used in the model call and the content text will be truncated to max_tokens.
+    Otherwise, max_tokens is used in the model call and everything after the separator is returned as content.
+
+    NOTE(review): only translate_text_prompt is wrapped here, while evaluate() and translate_response() read
+    request.request. Confirm how chat prompts are meant to flow through SUTs that use this mixin.
+    """
+
+    def __init__(self, uid, *args, **kwargs):
+        super().__init__(uid, *args, **kwargs)
+        self.tokenizer = GeneralTokenizer()
+        # Tag that separates reasoning from content. It must be non-empty: str.find("") is always 0,
+        # which would treat the whole output as content and never strip any reasoning.
+        self.separator = "</think>"
+
+    def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> ReasoningRequest:
+        """Build the underlying request with the total token budget and record the content budget."""
+        max_content_tokens = options.max_tokens
+        max_total_tokens = options.max_total_output_tokens
+        if max_total_tokens is None:
+            max_total_tokens = max_content_tokens
+
+        # Replace max_tokens in raw request with the max total tokens.
+        # Work on a copy so the caller's ModelOptions is left unchanged.
+        total_options = copy.copy(options)
+        total_options.max_tokens = max_total_tokens
+        request = super().translate_text_prompt(prompt, total_options)
+        return ReasoningRequest(
+            request=request,
+            max_content_tokens=max_content_tokens,
+            max_total_tokens=max_total_tokens,
+        )
+
+    def evaluate(self, request: ReasoningRequest) -> Any:
+        """Evaluate the wrapped request with the underlying SUT."""
+        return super().evaluate(request.request)  # type: ignore
+
+    def translate_response(self, request: ReasoningRequest, response: Any) -> SUTResponse:
+        """Drop everything up to and including the separator, then truncate the remaining content."""
+        text = super().translate_response(request.request, response).text  # type: ignore
+
+        think_close = text.find(self.separator)
+        if think_close == -1:
+            # no closing tag: everything is thinking text
+            return SUTResponse(text="")
+
+        content_start = think_close + len(self.separator)
+        reasoning = text[:content_start].strip()
+        content = text[content_start:].strip()
+        self.warn_edge_cases(content, reasoning, request)
+
+        # Truncate content
+        if request.max_content_tokens is not None:
+            content = self.tokenizer.truncate(content, request.max_content_tokens)
+        return SUTResponse(text=content)
+
+    def warn_edge_cases(self, content, reasoning, request):
+        """Log a warning when the reasoning text appears to have used up the content's share of the budget."""
+        # Both budgets are needed to work out how many tokens were left over for reasoning.
+        if request.max_total_tokens is None or request.max_content_tokens is None:
+            return
+        reasoning_tokens = self.tokenizer.count_tokens(reasoning)
+        content_tokens = self.tokenizer.count_tokens(content)
+        # Use the budgets stored on the wrapper: the wrapped request is typed Any and may not expose max_tokens.
+        reasoning_budget = request.max_total_tokens - request.max_content_tokens
+
+        if reasoning_tokens >= reasoning_budget and content_tokens + reasoning_tokens >= request.max_total_tokens:
+            logger.warning(
+                f"SUT {self.uid} reasoning likely ate into the token budget of the actual output. Consider increasing max_total_output_tokens."
+ ) diff --git a/src/modelgauge/suts/huggingface_chat_completion.py b/src/modelgauge/suts/huggingface_chat_completion.py index 2ab001685..84342ee0d 100644 --- a/src/modelgauge/suts/huggingface_chat_completion.py +++ b/src/modelgauge/suts/huggingface_chat_completion.py @@ -9,11 +9,12 @@ from requests.exceptions import HTTPError from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken +from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.prompt import TextPrompt, ChatPrompt +from modelgauge.reasoning_handlers import ThinkingMixin from modelgauge.retry_decorator import retry from modelgauge.secret_values import InjectSecret from modelgauge.sut import PromptResponseSUT, SUTResponse -from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS @@ -185,6 +186,14 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu ) +@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt]) +class HuggingFaceChatCompletionDedicatedThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionDedicatedSUT): + """ + A SUT that excludes the reasoning from model output. 
+ Reasoning must be separated from normal output with a tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) + """ + + @modelgauge_sut(capabilities=[AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities]) class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT): """A SUT hosted by an inference provider on huggingface.""" @@ -259,7 +268,16 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu None, HF_SECRET, ) +# Special thinking SUT +SUTS.register( + HuggingFaceChatCompletionDedicatedThinkingSUT, + "nvidia-nemotron-3-nano-30b-a-thinking-excluded-hf", + "nvidia-nemotron-3-nano-30b-a-mia", + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", + HF_SECRET, +) +# Register serverless SUTs. SUTS.register( HuggingFaceChatCompletionServerlessSUT, "cohere-c4ai-command-a-03-2025-hf", diff --git a/src/modelgauge/suts/together_client.py b/src/modelgauge/suts/together_client.py index 46dc42dab..911453d69 100644 --- a/src/modelgauge/suts/together_client.py +++ b/src/modelgauge/suts/together_client.py @@ -1,5 +1,5 @@ import time -from typing import Any, List, Optional +from typing import List, Optional import requests # type:ignore from modellogger.log_config import get_logger @@ -10,13 +10,13 @@ from modelgauge.general import APIException from modelgauge.prompt import ChatPrompt, ChatRole, TextPrompt from modelgauge.prompt_formatting import format_chat +from modelgauge.reasoning_handlers import ThinkingMixin from modelgauge.secret_values import InjectSecret from modelgauge.sut import PromptResponseSUT, SUTResponse from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS -from modelgauge.tokenizer import GeneralTokenizer logger = get_logger(__name__) @@ -271,64 +271,10 @@ def 
translate_response(self, request: TogetherChatRequest, response: TogetherCha return SUTResponse(text=text, top_logprobs=logprobs) -class TogetherThinkingChatRequest(TogetherChatRequest): - # max_tokens is for total output, including thinking text. - max_tokens_excl_thinking: Optional[int] = None - - @modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt]) -class TogetherThinkingSUT(TogetherChatSUT): +class TogetherThinkingSUT(ThinkingMixin, TogetherChatSUT): """SUT that preforms reasoning like deepseek-r1""" - def __init__(self, uid: str, model, api_key: TogetherApiKey): - super().__init__(uid, model, api_key) - self.tokenizer = GeneralTokenizer() - - def _translate_request( - self, messages: List[TogetherChatRequest.Message], options: ModelOptions - ) -> TogetherThinkingChatRequest: - max_tokens = options.max_total_output_tokens - if max_tokens is None: - max_tokens = options.max_tokens - return TogetherThinkingChatRequest( - model=self.model, - messages=messages, - max_tokens=max_tokens, - max_tokens_excl_thinking=options.max_tokens, # This will be ignored by the model but we use it to truncate - stop=options.stop_sequences, - temperature=options.temperature, - top_p=options.top_p, - top_k=options.top_k_per_token, - repetition_penalty=options.frequency_penalty, - ) - - def translate_response(self, request: TogetherThinkingChatRequest, response: TogetherChatResponse) -> SUTResponse: - assert len(response.choices) == 1, f"Expected 1 completion, got {len(response.choices)}." - choice = response.choices[0] - text = choice.message.content - assert text is not None - response = self._parse_response_text(request.max_tokens_excl_thinking, text) - return SUTResponse(text=response) - - def _parse_response_text(self, max_tokens: int | None, text: str) -> str: - """Discard thinking text and truncate to max tokens.""" - # If other reasoning SUTs follow this pattern, this logic can be extracted to a mixin. - # Make sure to move unit tests as well. 
- - # First discard thinking text. - if text.find("") != 0: - raise ValueError(f"Expected {self.uid} response to start with tag. Got: {text}") - think_close = text.find("") - if think_close == -1: - # no closing tag: everything is thinking text - return "" - - response = text[think_close + len("") :].strip() - - if max_tokens is None: - return response - return self.tokenizer.truncate(response, max_tokens) - @modelgauge_sut( capabilities=[ diff --git a/src/modelgauge/tokenizer.py b/src/modelgauge/tokenizer.py index db7309c36..e5e434543 100644 --- a/src/modelgauge/tokenizer.py +++ b/src/modelgauge/tokenizer.py @@ -17,6 +17,10 @@ def encoding(self): def _get_encoding(self): pass + def count_tokens(self, text: str) -> int: + tokens = self.encoding.encode(text) + return len(tokens) + def truncate(self, text: str, max_tokens: int) -> str: tokens = self.encoding.encode(text) if len(tokens) > max_tokens: diff --git a/tests/modelgauge_tests/sut_tests/test_together_client.py b/tests/modelgauge_tests/sut_tests/test_together_client.py index f071234cd..e865a26c8 100644 --- a/tests/modelgauge_tests/sut_tests/test_together_client.py +++ b/tests/modelgauge_tests/sut_tests/test_together_client.py @@ -18,8 +18,6 @@ TogetherCompletionsRequest, TogetherCompletionsSUT, TogetherDedicatedChatSUT, - TogetherThinkingChatRequest, - TogetherThinkingSUT, ) @@ -549,94 +547,3 @@ def side_effect(url, headers, json_payload, method): # Verify non-400 error is re-raised with pytest.raises(APIException, match="Internal Server Error \\(500\\)"): sut.evaluate(request) - - -class TestTogetherThinkingSUT: - @pytest.fixture - def sut(self): - sut = TogetherThinkingSUT( - uid="test-model", - model="some-model", - api_key=TogetherApiKey("some-value"), - ) - return sut - - def test_translate_text_prompt_sets_max_tokens(self, sut): - prompt = TextPrompt(text="some-text") - - options = ModelOptions(max_tokens=50) - request = sut.translate_text_prompt(prompt, options) - assert request.max_tokens == 50 - 
assert request.max_tokens_excl_thinking == 50 - - options = ModelOptions(max_tokens=50, max_total_output_tokens=200) - request = sut.translate_text_prompt(prompt, options) - assert request.max_tokens == 200 - assert request.max_tokens_excl_thinking == 50 - - options = ModelOptions(max_total_output_tokens=200) - request = sut.translate_text_prompt(prompt, options) - assert request.max_tokens == 200 - assert request.max_tokens_excl_thinking == None # Default max tokens - - def test_translate_chat_prompt_sets_max_tokens(self, sut): - prompt = ChatPrompt(messages=[]) - - options = ModelOptions(max_tokens=50) - request = sut.translate_chat_prompt(prompt, options) - assert request.max_tokens == 50 - assert request.max_tokens_excl_thinking == 50 - - options = ModelOptions(max_tokens=50, max_total_output_tokens=200) - request = sut.translate_chat_prompt(prompt, options) - assert request.max_tokens == 200 - assert request.max_tokens_excl_thinking == 50 - - options = ModelOptions(max_total_output_tokens=200) - request = sut.translate_chat_prompt(prompt, options) - assert request.max_tokens == 200 - assert request.max_tokens_excl_thinking == None - - @pytest.mark.parametrize( - "full_text, response_text", [("hmm\\n Output", "Output"), ("hmmm", "")] - ) - def test_translate_response_no_truncation(self, full_text, response_text, sut): - # No max_tokens_excl_thinking in request, so no truncation. 
- request = TogetherThinkingChatRequest(model="some-model", messages=[]) - response_json = _make_response_json(full_text) - response = TogetherChatResponse.model_validate_json(response_json) - - result = sut.translate_response(request, response) - assert result.text == response_text - - @pytest.mark.parametrize("full_text", ["", "No thinking", "Late thinking hmm"]) - def test_improper_response_formatting_raises_error(self, full_text, sut): - request = TogetherThinkingChatRequest(model="some-model", messages=[]) - response_json = _make_response_json(full_text) - response = TogetherChatResponse.model_validate_json(response_json) - - with pytest.raises(ValueError): - sut.translate_response(request, response) - - @pytest.mark.parametrize( - "full_text, response_text", - [ - ("hmmone two three", "one two"), - ("one", "one"), - ("", ""), - ("hmmm", ""), - ], - ) - def test_truncation(self, full_text, response_text): - sut = TogetherThinkingSUT( - uid="test-model", - model="some-model", - api_key=TogetherApiKey("some-value"), - ) - - request = TogetherThinkingChatRequest(model="some-model", messages=[], max_tokens_excl_thinking=2) - response_json = _make_response_json(full_text) - response = TogetherChatResponse.model_validate_json(response_json) - - result = sut.translate_response(request, response) - assert result.text == response_text diff --git a/tests/modelgauge_tests/test_reasoning_handlers.py b/tests/modelgauge_tests/test_reasoning_handlers.py new file mode 100644 index 000000000..e2d219965 --- /dev/null +++ b/tests/modelgauge_tests/test_reasoning_handlers.py @@ -0,0 +1,113 @@ +import pytest + +from pydantic import BaseModel + +from modelgauge.model_options import ModelOptions +from modelgauge.prompt import TextPrompt +from modelgauge.reasoning_handlers import ReasoningRequest, ThinkingMixin + +from modelgauge.sut import SUTResponse, PromptResponseSUT +from modelgauge.sut_capabilities import AcceptsTextPrompt +from modelgauge.sut_decorator import modelgauge_sut + + 
+class FakeSUTRequest(BaseModel):
+    text: str
+    max_tokens: int | None = None
+
+
+class FakeSUTResponse(BaseModel):
+    text: str
+
+
+class FakeBaseSUT(PromptResponseSUT):
+    def __init__(self, uid: str = "fake-sut"):
+        super().__init__(uid)
+
+    def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> FakeSUTRequest:
+        return FakeSUTRequest(text="prompt", max_tokens=options.max_tokens)
+
+    def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
+        return FakeSUTResponse(text="reasoningresponse")
+
+    def translate_response(self, request: FakeSUTRequest, response: FakeSUTResponse) -> SUTResponse:
+        return SUTResponse(text=response.text)
+
+
+class TestThinkMixin:
+    @pytest.fixture
+    def sut(self):
+        @modelgauge_sut(capabilities=[AcceptsTextPrompt])
+        class ThinkSut(ThinkingMixin, FakeBaseSUT):
+            pass
+
+        return ThinkSut("sut-uid")
+
+    def test_translate_text_prompt_sets_max_tokens(self, sut):
+        prompt = TextPrompt(text="some-text")
+
+        options = ModelOptions(max_tokens=50)
+        request = sut.translate_text_prompt(prompt, options)
+        assert request.request.max_tokens == 50
+        assert request.max_content_tokens == 50
+
+        options = ModelOptions(max_tokens=50, max_total_output_tokens=200)
+        request = sut.translate_text_prompt(prompt, options)
+        assert request.request.max_tokens == 200
+        assert request.max_content_tokens == 50
+
+        options = ModelOptions(max_total_output_tokens=200)
+        request = sut.translate_text_prompt(prompt, options)
+        assert request.request.max_tokens == 200
+        assert request.max_content_tokens is None  # Default max tokens
+
+        options = ModelOptions()
+        request = sut.translate_text_prompt(prompt, options)
+        assert request.request.max_tokens is None
+        assert request.max_content_tokens is None
+
+    @pytest.mark.parametrize(
+        "full_text, content_text",
+        [("hmm\n Output", "Output"), ("hmm\n Output", "Output"), ("hmmm", "")],
+    )
+    def test_translate_response_no_truncation(self, full_text, content_text, sut):
+        request = ReasoningRequest(
+            request=FakeSUTRequest(text="", max_tokens=100), max_content_tokens=100, max_total_tokens=100
+        )
+        response = FakeSUTResponse(text=full_text)
+
+        result = sut.translate_response(request, response)
+        assert result.text == content_text
+
+        request = ReasoningRequest(request=FakeSUTRequest(text=""))
+        response = FakeSUTResponse(text=full_text)
+
+        result = sut.translate_response(request, response)
+        assert result.text == content_text
+
+    @pytest.mark.parametrize(
+        "full_text, content_text",
+        [
+            ("hmmone two three", "one two"),
+            ("one", "one"),
+            ("", ""),
+            ("hmmm", ""),
+        ],
+    )
+    def test_truncation(self, full_text, content_text, sut):
+        request = ReasoningRequest(
+            request=FakeSUTRequest(text="", max_tokens=100), max_content_tokens=2, max_total_tokens=100
+        )
+        response = FakeSUTResponse(text=full_text)
+
+        result = sut.translate_response(request, response)
+        assert result.text == content_text
+
+    def test_translate_response_warns_reasoning_over_budget(self, sut, caplog):
+        request = ReasoningRequest(
+            request=FakeSUTRequest(text="", max_tokens=5), max_content_tokens=2, max_total_tokens=5
+        )
+        response = FakeSUTResponse(text="one two three four five")
+
+        sut.translate_response(request, response)
+        assert "reasoning likely ate into the token budget of the actual output" in caplog.text