diff --git a/src/modelgauge/reasoning_handlers.py b/src/modelgauge/reasoning_handlers.py
new file mode 100644
index 000000000..395616633
--- /dev/null
+++ b/src/modelgauge/reasoning_handlers.py
@@ -0,0 +1,78 @@
+from typing import Any
+from modellogger.log_config import get_logger
+from pydantic import BaseModel
+
+from modelgauge.model_options import ModelOptions
+from modelgauge.prompt import TextPrompt
+from modelgauge.sut import SUTResponse, PromptResponseSUT
+from modelgauge.tokenizer import GeneralTokenizer
+
+logger = get_logger(__name__)
+
+
+class ReasoningRequest(BaseModel):
+ request: Any # Request that is actually sent to the model.
+ max_content_tokens: int | None = None # Number of tokens allowed for content (excluding thinking text).
+ max_total_tokens: int | None = None # Total number of tokens allowed (thinking + content).
+
+
+class ThinkingMixin(PromptResponseSUT):
+ """
+ A mixin for SUTs that parses out thinking text from the output.
+
+ The output is expected to be in the form: {reasoning text}{content text}.
+ If max_total_output_tokens is set in ModelOptions, that value will be used in the model call and the content text will be truncated to max_tokens.
+ Otherwise, max_tokens is used in the model call and everything after is returned as content.
+ """
+
+ def __init__(self, uid, *args, **kwargs):
+ super().__init__(uid, *args, **kwargs)
+ self.tokenizer = GeneralTokenizer()
+ self.separator = "" # Tag that separates reasoning from content.
+
+ def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> ReasoningRequest:
+ max_total_tokens = options.max_total_output_tokens
+ if max_total_tokens is None:
+ max_total_tokens = options.max_tokens
+ max_content_tokens = options.max_tokens
+
+ # Replace max_tokens in raw request with the max total tokens.
+ options.max_tokens = max_total_tokens
+ request = super().translate_text_prompt(prompt, options)
+ return ReasoningRequest(
+ request=request,
+ max_content_tokens=max_content_tokens,
+ max_total_tokens=max_total_tokens,
+ )
+
+ def evaluate(self, request: ReasoningRequest) -> Any:
+ return super().evaluate(request.request) # type: ignore
+
+ def translate_response(self, request: ReasoningRequest, response: Any) -> SUTResponse:
+ text = super().translate_response(request.request, response).text # type: ignore
+
+ think_close = text.find(self.separator)
+ if think_close == -1:
+ # no closing tag: everything is thinking text
+ return SUTResponse(text="")
+
+ reasoning = text[: think_close + len(self.separator)].strip()
+ content = text[think_close + len(self.separator) :].strip()
+ self.warn_edge_cases(content, reasoning, request)
+
+ # Truncate content
+ if request.max_content_tokens is not None:
+ content = self.tokenizer.truncate(content, request.max_content_tokens)
+ return SUTResponse(text=content)
+
+ def warn_edge_cases(self, content, reasoning, request):
+ if request.max_total_tokens is None:
+ return
+ reasoning_tokens = self.tokenizer.count_tokens(reasoning)
+ content_tokens = self.tokenizer.count_tokens(content)
+ reasoning_budget = request.request.max_tokens - request.max_content_tokens
+
+ if reasoning_tokens >= reasoning_budget and content_tokens + reasoning_tokens >= request.max_total_tokens:
+ logger.warning(
+ f"SUT {self.uid} reasoning likely ate into the token budget of the actual output. Consider increasing max_total_output_tokens."
+ )
diff --git a/src/modelgauge/suts/huggingface_chat_completion.py b/src/modelgauge/suts/huggingface_chat_completion.py
index 2ab001685..84342ee0d 100644
--- a/src/modelgauge/suts/huggingface_chat_completion.py
+++ b/src/modelgauge/suts/huggingface_chat_completion.py
@@ -9,11 +9,12 @@
from requests.exceptions import HTTPError
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
+from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.prompt import TextPrompt, ChatPrompt
+from modelgauge.reasoning_handlers import ThinkingMixin
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTResponse
-from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
@@ -185,6 +186,14 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu
)
+@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
+class HuggingFaceChatCompletionDedicatedThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionDedicatedSUT):
+ """
+ A SUT that excludes the reasoning from model output.
+ Reasoning must be seperated from normal output with a tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)
+ """
+
+
@modelgauge_sut(capabilities=[AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT):
"""A SUT hosted by an inference provider on huggingface."""
@@ -259,7 +268,16 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu
None,
HF_SECRET,
)
+# Special thinking SUT
+SUTS.register(
+ HuggingFaceChatCompletionDedicatedThinkingSUT,
+ "nvidia-nemotron-3-nano-30b-a-thinking-excluded-hf",
+ "nvidia-nemotron-3-nano-30b-a-mia",
+ "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
+ HF_SECRET,
+)
+# Register serverless SUTs.
SUTS.register(
HuggingFaceChatCompletionServerlessSUT,
"cohere-c4ai-command-a-03-2025-hf",
diff --git a/src/modelgauge/suts/together_client.py b/src/modelgauge/suts/together_client.py
index 46dc42dab..911453d69 100644
--- a/src/modelgauge/suts/together_client.py
+++ b/src/modelgauge/suts/together_client.py
@@ -1,5 +1,5 @@
import time
-from typing import Any, List, Optional
+from typing import List, Optional
import requests # type:ignore
from modellogger.log_config import get_logger
@@ -10,13 +10,13 @@
from modelgauge.general import APIException
from modelgauge.prompt import ChatPrompt, ChatRole, TextPrompt
from modelgauge.prompt_formatting import format_chat
+from modelgauge.reasoning_handlers import ThinkingMixin
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTResponse
from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
-from modelgauge.tokenizer import GeneralTokenizer
logger = get_logger(__name__)
@@ -271,64 +271,10 @@ def translate_response(self, request: TogetherChatRequest, response: TogetherCha
return SUTResponse(text=text, top_logprobs=logprobs)
-class TogetherThinkingChatRequest(TogetherChatRequest):
- # max_tokens is for total output, including thinking text.
- max_tokens_excl_thinking: Optional[int] = None
-
-
@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
-class TogetherThinkingSUT(TogetherChatSUT):
+class TogetherThinkingSUT(ThinkingMixin, TogetherChatSUT):
"""SUT that preforms reasoning like deepseek-r1"""
- def __init__(self, uid: str, model, api_key: TogetherApiKey):
- super().__init__(uid, model, api_key)
- self.tokenizer = GeneralTokenizer()
-
- def _translate_request(
- self, messages: List[TogetherChatRequest.Message], options: ModelOptions
- ) -> TogetherThinkingChatRequest:
- max_tokens = options.max_total_output_tokens
- if max_tokens is None:
- max_tokens = options.max_tokens
- return TogetherThinkingChatRequest(
- model=self.model,
- messages=messages,
- max_tokens=max_tokens,
- max_tokens_excl_thinking=options.max_tokens, # This will be ignored by the model but we use it to truncate
- stop=options.stop_sequences,
- temperature=options.temperature,
- top_p=options.top_p,
- top_k=options.top_k_per_token,
- repetition_penalty=options.frequency_penalty,
- )
-
- def translate_response(self, request: TogetherThinkingChatRequest, response: TogetherChatResponse) -> SUTResponse:
- assert len(response.choices) == 1, f"Expected 1 completion, got {len(response.choices)}."
- choice = response.choices[0]
- text = choice.message.content
- assert text is not None
- response = self._parse_response_text(request.max_tokens_excl_thinking, text)
- return SUTResponse(text=response)
-
- def _parse_response_text(self, max_tokens: int | None, text: str) -> str:
- """Discard thinking text and truncate to max tokens."""
- # If other reasoning SUTs follow this pattern, this logic can be extracted to a mixin.
- # Make sure to move unit tests as well.
-
- # First discard thinking text.
- if text.find("") != 0:
- raise ValueError(f"Expected {self.uid} response to start with tag. Got: {text}")
- think_close = text.find("")
- if think_close == -1:
- # no closing tag: everything is thinking text
- return ""
-
- response = text[think_close + len("") :].strip()
-
- if max_tokens is None:
- return response
- return self.tokenizer.truncate(response, max_tokens)
-
@modelgauge_sut(
capabilities=[
diff --git a/src/modelgauge/tokenizer.py b/src/modelgauge/tokenizer.py
index db7309c36..e5e434543 100644
--- a/src/modelgauge/tokenizer.py
+++ b/src/modelgauge/tokenizer.py
@@ -17,6 +17,10 @@ def encoding(self):
def _get_encoding(self):
pass
+ def count_tokens(self, text: str) -> int:
+ tokens = self.encoding.encode(text)
+ return len(tokens)
+
def truncate(self, text: str, max_tokens: int) -> str:
tokens = self.encoding.encode(text)
if len(tokens) > max_tokens:
diff --git a/tests/modelgauge_tests/sut_tests/test_together_client.py b/tests/modelgauge_tests/sut_tests/test_together_client.py
index f071234cd..e865a26c8 100644
--- a/tests/modelgauge_tests/sut_tests/test_together_client.py
+++ b/tests/modelgauge_tests/sut_tests/test_together_client.py
@@ -18,8 +18,6 @@
TogetherCompletionsRequest,
TogetherCompletionsSUT,
TogetherDedicatedChatSUT,
- TogetherThinkingChatRequest,
- TogetherThinkingSUT,
)
@@ -549,94 +547,3 @@ def side_effect(url, headers, json_payload, method):
# Verify non-400 error is re-raised
with pytest.raises(APIException, match="Internal Server Error \\(500\\)"):
sut.evaluate(request)
-
-
-class TestTogetherThinkingSUT:
- @pytest.fixture
- def sut(self):
- sut = TogetherThinkingSUT(
- uid="test-model",
- model="some-model",
- api_key=TogetherApiKey("some-value"),
- )
- return sut
-
- def test_translate_text_prompt_sets_max_tokens(self, sut):
- prompt = TextPrompt(text="some-text")
-
- options = ModelOptions(max_tokens=50)
- request = sut.translate_text_prompt(prompt, options)
- assert request.max_tokens == 50
- assert request.max_tokens_excl_thinking == 50
-
- options = ModelOptions(max_tokens=50, max_total_output_tokens=200)
- request = sut.translate_text_prompt(prompt, options)
- assert request.max_tokens == 200
- assert request.max_tokens_excl_thinking == 50
-
- options = ModelOptions(max_total_output_tokens=200)
- request = sut.translate_text_prompt(prompt, options)
- assert request.max_tokens == 200
- assert request.max_tokens_excl_thinking == None # Default max tokens
-
- def test_translate_chat_prompt_sets_max_tokens(self, sut):
- prompt = ChatPrompt(messages=[])
-
- options = ModelOptions(max_tokens=50)
- request = sut.translate_chat_prompt(prompt, options)
- assert request.max_tokens == 50
- assert request.max_tokens_excl_thinking == 50
-
- options = ModelOptions(max_tokens=50, max_total_output_tokens=200)
- request = sut.translate_chat_prompt(prompt, options)
- assert request.max_tokens == 200
- assert request.max_tokens_excl_thinking == 50
-
- options = ModelOptions(max_total_output_tokens=200)
- request = sut.translate_chat_prompt(prompt, options)
- assert request.max_tokens == 200
- assert request.max_tokens_excl_thinking == None
-
- @pytest.mark.parametrize(
- "full_text, response_text", [("hmm\\n Output", "Output"), ("hmmm", "")]
- )
- def test_translate_response_no_truncation(self, full_text, response_text, sut):
- # No max_tokens_excl_thinking in request, so no truncation.
- request = TogetherThinkingChatRequest(model="some-model", messages=[])
- response_json = _make_response_json(full_text)
- response = TogetherChatResponse.model_validate_json(response_json)
-
- result = sut.translate_response(request, response)
- assert result.text == response_text
-
- @pytest.mark.parametrize("full_text", ["", "No thinking", "Late thinking hmm"])
- def test_improper_response_formatting_raises_error(self, full_text, sut):
- request = TogetherThinkingChatRequest(model="some-model", messages=[])
- response_json = _make_response_json(full_text)
- response = TogetherChatResponse.model_validate_json(response_json)
-
- with pytest.raises(ValueError):
- sut.translate_response(request, response)
-
- @pytest.mark.parametrize(
- "full_text, response_text",
- [
- ("hmmone two three", "one two"),
- ("one", "one"),
- ("", ""),
- ("hmmm", ""),
- ],
- )
- def test_truncation(self, full_text, response_text):
- sut = TogetherThinkingSUT(
- uid="test-model",
- model="some-model",
- api_key=TogetherApiKey("some-value"),
- )
-
- request = TogetherThinkingChatRequest(model="some-model", messages=[], max_tokens_excl_thinking=2)
- response_json = _make_response_json(full_text)
- response = TogetherChatResponse.model_validate_json(response_json)
-
- result = sut.translate_response(request, response)
- assert result.text == response_text
diff --git a/tests/modelgauge_tests/test_reasoning_handlers.py b/tests/modelgauge_tests/test_reasoning_handlers.py
new file mode 100644
index 000000000..e2d219965
--- /dev/null
+++ b/tests/modelgauge_tests/test_reasoning_handlers.py
@@ -0,0 +1,113 @@
+import pytest
+
+from pydantic import BaseModel
+
+from modelgauge.model_options import ModelOptions
+from modelgauge.prompt import TextPrompt
+from modelgauge.reasoning_handlers import ReasoningRequest, ThinkingMixin
+
+from modelgauge.sut import SUTResponse, PromptResponseSUT
+from modelgauge.sut_capabilities import AcceptsTextPrompt
+from modelgauge.sut_decorator import modelgauge_sut
+
+
+class FakeSUTRequest(BaseModel):
+ text: str
+ max_tokens: int | None = None
+
+
+class FakeSUTResponse(BaseModel):
+ text: str
+
+
+class FakeBaseSUT(PromptResponseSUT):
+ def __init__(self, uid: str = "fake-sut"):
+ super().__init__(uid)
+
+ def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> FakeSUTRequest:
+ return FakeSUTRequest(text="prompt", max_tokens=options.max_tokens)
+
+ def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
+ return FakeSUTResponse(text="reasoningresponse")
+
+ def translate_response(self, request: FakeSUTRequest, response: FakeSUTResponse) -> SUTResponse:
+ return SUTResponse(text=response.text)
+
+
+class TestThinkMixin:
+ @pytest.fixture
+ def sut(self):
+ @modelgauge_sut(capabilities=[AcceptsTextPrompt])
+ class ThinkSut(ThinkingMixin, FakeBaseSUT):
+ pass
+
+ return ThinkSut("sut-uid")
+
+ def test_translate_text_prompt_sets_max_tokens(self, sut):
+ prompt = TextPrompt(text="some-text")
+
+ options = ModelOptions(max_tokens=50)
+ request = sut.translate_text_prompt(prompt, options)
+ assert request.request.max_tokens == 50
+ assert request.max_content_tokens == 50
+
+ options = ModelOptions(max_tokens=50, max_total_output_tokens=200)
+ request = sut.translate_text_prompt(prompt, options)
+ assert request.request.max_tokens == 200
+ assert request.max_content_tokens == 50
+
+ options = ModelOptions(max_total_output_tokens=200)
+ request = sut.translate_text_prompt(prompt, options)
+ assert request.request.max_tokens == 200
+ assert request.max_content_tokens == None # Default max tokens
+
+ options = ModelOptions()
+ request = sut.translate_text_prompt(prompt, options)
+ assert request.request.max_tokens == None
+ assert request.max_content_tokens == None
+
+ @pytest.mark.parametrize(
+ "full_text, content_text",
+ [("hmm\n Output", "Output"), ("hmm\n Output", "Output"), ("hmmm", "")],
+ )
+ def test_translate_response_no_truncation(self, full_text, content_text, sut):
+ request = ReasoningRequest(
+ request=FakeSUTRequest(text="", max_tokens=100), max_content_tokens=100, max_total_tokens=100
+ )
+ response = FakeSUTResponse(text=full_text)
+
+ result = sut.translate_response(request, response)
+ assert result.text == content_text
+
+ request = ReasoningRequest(request=FakeSUTRequest(text=""))
+ response = FakeSUTResponse(text=full_text)
+
+ result = sut.translate_response(request, response)
+ assert result.text == content_text
+
+ @pytest.mark.parametrize(
+ "full_text, content_text",
+ [
+ ("hmmone two three", "one two"),
+ ("one", "one"),
+ ("", ""),
+ ("hmmm", ""),
+ ],
+ )
+ def test_truncation(self, full_text, content_text, sut):
+ request = ReasoningRequest(
+ request=FakeSUTRequest(text="", max_tokens=100), max_content_tokens=2, max_total_tokens=100
+ )
+ response = FakeSUTResponse(text=full_text)
+
+ result = sut.translate_response(request, response)
+ assert result.text == content_text
+
+ def test_translate_response_warns_reasoning_over_budget(self, sut, caplog):
+ request = ReasoningRequest(
+ request=FakeSUTRequest(text="", max_tokens=5), max_content_tokens=2, max_total_tokens=5
+ )
+ response = FakeSUTResponse(text="one two three four five")
+
+ result = sut.translate_response(request, response)
+ assert "reasoning likely ate into the token budget of the actual output" in caplog.text