Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions src/modelgauge/reasoning_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Any
from modellogger.log_config import get_logger
from pydantic import BaseModel

from modelgauge.model_options import ModelOptions
from modelgauge.prompt import TextPrompt
from modelgauge.sut import SUTResponse, PromptResponseSUT
from modelgauge.tokenizer import GeneralTokenizer

logger = get_logger(__name__)


class ReasoningRequest(BaseModel):
Comment thread
superdosh marked this conversation as resolved.
request: Any # Request that is actually sent to the model.
max_content_tokens: int | None = None # Number of tokens allowed for content (excluding thinking text).
max_total_tokens: int | None = None # Total number of tokens allowed (thinking + content).


class ThinkingMixin(PromptResponseSUT):
"""
A mixin for SUTs that parses out thinking text from the output.

The output is expected to be in the form: {reasoning text}</think>{content text}.
If max_total_output_tokens is set in ModelOptions, that value will be used in the model call and the content text will be truncated to max_tokens.
Otherwise, max_tokens is used in the model call and everything after </think> is returned as content.
"""

def __init__(self, uid, *args, **kwargs):
super().__init__(uid, *args, **kwargs)
self.tokenizer = GeneralTokenizer()
self.separator = "</think>" # Tag that separates reasoning from content.

def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> ReasoningRequest:
max_total_tokens = options.max_total_output_tokens
if max_total_tokens is None:
max_total_tokens = options.max_tokens
max_content_tokens = options.max_tokens

# Replace max_tokens in raw request with the max total tokens.
options.max_tokens = max_total_tokens
request = super().translate_text_prompt(prompt, options)
return ReasoningRequest(
request=request,
max_content_tokens=max_content_tokens,
max_total_tokens=max_total_tokens,
)

def evaluate(self, request: ReasoningRequest) -> Any:
return super().evaluate(request.request) # type: ignore

def translate_response(self, request: ReasoningRequest, response: Any) -> SUTResponse:
text = super().translate_response(request.request, response).text # type: ignore
Comment thread
wpietri marked this conversation as resolved.

think_close = text.find(self.separator)
if think_close == -1:
# no closing tag: everything is thinking text
return SUTResponse(text="")

reasoning = text[: think_close + len(self.separator)].strip()
content = text[think_close + len(self.separator) :].strip()
self.warn_edge_cases(content, reasoning, request)

# Truncate content
if request.max_content_tokens is not None:
content = self.tokenizer.truncate(content, request.max_content_tokens)
return SUTResponse(text=content)

def warn_edge_cases(self, content, reasoning, request):
if request.max_total_tokens is None:
return
reasoning_tokens = self.tokenizer.count_tokens(reasoning)
content_tokens = self.tokenizer.count_tokens(content)
reasoning_budget = request.request.max_tokens - request.max_content_tokens

if reasoning_tokens >= reasoning_budget and content_tokens + reasoning_tokens >= request.max_total_tokens:
logger.warning(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't do anything with warnings in benchmark runs right now. Should we?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. Kurt wanted this logged so that we can do some manual analysis later on.

f"SUT {self.uid} reasoning likely ate into the token budget of the actual output. Consider increasing max_total_output_tokens."
Comment thread
wpietri marked this conversation as resolved.
)
20 changes: 19 additions & 1 deletion src/modelgauge/suts/huggingface_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from requests.exceptions import HTTPError

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.prompt import TextPrompt, ChatPrompt
from modelgauge.reasoning_handlers import ThinkingMixin
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTResponse
from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
Expand Down Expand Up @@ -185,6 +186,14 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu
)


@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
class HuggingFaceChatCompletionDedicatedThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionDedicatedSUT):
Comment thread
superdosh marked this conversation as resolved.
"""
A SUT that excludes the reasoning from model output.
Reasoning must be seperated from normal output with a </think> tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)
"""


@modelgauge_sut(capabilities=[AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT):
"""A SUT hosted by an inference provider on huggingface."""
Expand Down Expand Up @@ -259,7 +268,16 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu
None,
HF_SECRET,
)
# Special thinking SUT
SUTS.register(
HuggingFaceChatCompletionDedicatedThinkingSUT,
"nvidia-nemotron-3-nano-30b-a-thinking-excluded-hf",
"nvidia-nemotron-3-nano-30b-a-mia",
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
HF_SECRET,
)

# Register serverless SUTs.
SUTS.register(
HuggingFaceChatCompletionServerlessSUT,
"cohere-c4ai-command-a-03-2025-hf",
Expand Down
60 changes: 3 additions & 57 deletions src/modelgauge/suts/together_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Any, List, Optional
from typing import List, Optional

import requests # type:ignore
from modellogger.log_config import get_logger
Expand All @@ -10,13 +10,13 @@
from modelgauge.general import APIException
from modelgauge.prompt import ChatPrompt, ChatRole, TextPrompt
from modelgauge.prompt_formatting import format_chat
from modelgauge.reasoning_handlers import ThinkingMixin
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTResponse
from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.tokenizer import GeneralTokenizer

logger = get_logger(__name__)

Expand Down Expand Up @@ -271,64 +271,10 @@ def translate_response(self, request: TogetherChatRequest, response: TogetherCha
return SUTResponse(text=text, top_logprobs=logprobs)


class TogetherThinkingChatRequest(TogetherChatRequest):
# max_tokens is for total output, including thinking text.
max_tokens_excl_thinking: Optional[int] = None


@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
class TogetherThinkingSUT(TogetherChatSUT):
class TogetherThinkingSUT(ThinkingMixin, TogetherChatSUT):
"""SUT that preforms reasoning like deepseek-r1"""

def __init__(self, uid: str, model, api_key: TogetherApiKey):
super().__init__(uid, model, api_key)
self.tokenizer = GeneralTokenizer()

def _translate_request(
self, messages: List[TogetherChatRequest.Message], options: ModelOptions
) -> TogetherThinkingChatRequest:
max_tokens = options.max_total_output_tokens
if max_tokens is None:
max_tokens = options.max_tokens
return TogetherThinkingChatRequest(
model=self.model,
messages=messages,
max_tokens=max_tokens,
max_tokens_excl_thinking=options.max_tokens, # This will be ignored by the model but we use it to truncate
stop=options.stop_sequences,
temperature=options.temperature,
top_p=options.top_p,
top_k=options.top_k_per_token,
repetition_penalty=options.frequency_penalty,
)

def translate_response(self, request: TogetherThinkingChatRequest, response: TogetherChatResponse) -> SUTResponse:
assert len(response.choices) == 1, f"Expected 1 completion, got {len(response.choices)}."
choice = response.choices[0]
text = choice.message.content
assert text is not None
response = self._parse_response_text(request.max_tokens_excl_thinking, text)
return SUTResponse(text=response)

def _parse_response_text(self, max_tokens: int | None, text: str) -> str:
"""Discard thinking text and truncate to max tokens."""
# If other reasoning SUTs follow this pattern, this logic can be extracted to a mixin.
# Make sure to move unit tests as well.

# First discard thinking text.
if text.find("<think>") != 0:
raise ValueError(f"Expected {self.uid} response to start with <think> tag. Got: {text}")
think_close = text.find("</think>")
if think_close == -1:
# no closing tag: everything is thinking text
return ""

response = text[think_close + len("</think>") :].strip()

if max_tokens is None:
return response
return self.tokenizer.truncate(response, max_tokens)


@modelgauge_sut(
capabilities=[
Expand Down
4 changes: 4 additions & 0 deletions src/modelgauge/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def encoding(self):
def _get_encoding(self):
pass

def count_tokens(self, text: str) -> int:
tokens = self.encoding.encode(text)
return len(tokens)

def truncate(self, text: str, max_tokens: int) -> str:
tokens = self.encoding.encode(text)
if len(tokens) > max_tokens:
Expand Down
93 changes: 0 additions & 93 deletions tests/modelgauge_tests/sut_tests/test_together_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
TogetherCompletionsRequest,
TogetherCompletionsSUT,
TogetherDedicatedChatSUT,
TogetherThinkingChatRequest,
TogetherThinkingSUT,
)


Expand Down Expand Up @@ -549,94 +547,3 @@ def side_effect(url, headers, json_payload, method):
# Verify non-400 error is re-raised
with pytest.raises(APIException, match="Internal Server Error \\(500\\)"):
sut.evaluate(request)


class TestTogetherThinkingSUT:
@pytest.fixture
def sut(self):
sut = TogetherThinkingSUT(
uid="test-model",
model="some-model",
api_key=TogetherApiKey("some-value"),
)
return sut

def test_translate_text_prompt_sets_max_tokens(self, sut):
prompt = TextPrompt(text="some-text")

options = ModelOptions(max_tokens=50)
request = sut.translate_text_prompt(prompt, options)
assert request.max_tokens == 50
assert request.max_tokens_excl_thinking == 50

options = ModelOptions(max_tokens=50, max_total_output_tokens=200)
request = sut.translate_text_prompt(prompt, options)
assert request.max_tokens == 200
assert request.max_tokens_excl_thinking == 50

options = ModelOptions(max_total_output_tokens=200)
request = sut.translate_text_prompt(prompt, options)
assert request.max_tokens == 200
assert request.max_tokens_excl_thinking == None # Default max tokens

def test_translate_chat_prompt_sets_max_tokens(self, sut):
prompt = ChatPrompt(messages=[])

options = ModelOptions(max_tokens=50)
request = sut.translate_chat_prompt(prompt, options)
assert request.max_tokens == 50
assert request.max_tokens_excl_thinking == 50

options = ModelOptions(max_tokens=50, max_total_output_tokens=200)
request = sut.translate_chat_prompt(prompt, options)
assert request.max_tokens == 200
assert request.max_tokens_excl_thinking == 50

options = ModelOptions(max_total_output_tokens=200)
request = sut.translate_chat_prompt(prompt, options)
assert request.max_tokens == 200
assert request.max_tokens_excl_thinking == None

@pytest.mark.parametrize(
"full_text, response_text", [("<think>hmm</think>\\n Output", "Output"), ("<think>hmmm", "")]
)
def test_translate_response_no_truncation(self, full_text, response_text, sut):
# No max_tokens_excl_thinking in request, so no truncation.
request = TogetherThinkingChatRequest(model="some-model", messages=[])
response_json = _make_response_json(full_text)
response = TogetherChatResponse.model_validate_json(response_json)

result = sut.translate_response(request, response)
assert result.text == response_text

@pytest.mark.parametrize("full_text", ["", "No thinking", "Late thinking <think>hmm</think>"])
def test_improper_response_formatting_raises_error(self, full_text, sut):
request = TogetherThinkingChatRequest(model="some-model", messages=[])
response_json = _make_response_json(full_text)
response = TogetherChatResponse.model_validate_json(response_json)

with pytest.raises(ValueError):
sut.translate_response(request, response)

@pytest.mark.parametrize(
"full_text, response_text",
[
("<think>hmm</think>one two three", "one two"),
("<think></think>one", "one"),
("<think></think>", ""),
("<think>hmmm", ""),
],
)
def test_truncation(self, full_text, response_text):
sut = TogetherThinkingSUT(
uid="test-model",
model="some-model",
api_key=TogetherApiKey("some-value"),
)

request = TogetherThinkingChatRequest(model="some-model", messages=[], max_tokens_excl_thinking=2)
response_json = _make_response_json(full_text)
response = TogetherChatResponse.model_validate_json(response_json)

result = sut.translate_response(request, response)
assert result.text == response_text
Loading