-
Notifications
You must be signed in to change notification settings - Fork 28
Thinking SUTs #1453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Thinking SUTs #1453
Changes from all commits
8db060c
6bd0639
8fa79e0
7fa3b24
c7d6444
346a288
76fc2a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| from typing import Any | ||
| from modellogger.log_config import get_logger | ||
| from pydantic import BaseModel | ||
|
|
||
| from modelgauge.model_options import ModelOptions | ||
| from modelgauge.prompt import TextPrompt | ||
| from modelgauge.sut import SUTResponse, PromptResponseSUT | ||
| from modelgauge.tokenizer import GeneralTokenizer | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| class ReasoningRequest(BaseModel): | ||
| request: Any # Request that is actually sent to the model. | ||
| max_content_tokens: int | None = None # Number of tokens allowed for content (excluding thinking text). | ||
| max_total_tokens: int | None = None # Total number of tokens allowed (thinking + content). | ||
|
|
||
|
|
||
| class ThinkingMixin(PromptResponseSUT): | ||
| """ | ||
| A mixin for SUTs that parses out thinking text from the output. | ||
|
|
||
| The output is expected to be in the form: {reasoning text}</think>{content text}. | ||
| If max_total_output_tokens is set in ModelOptions, that value will be used in the model call and the content text will be truncated to max_tokens. | ||
| Otherwise, max_tokens is used in the model call and everything after </think> is returned as content. | ||
| """ | ||
|
|
||
| def __init__(self, uid, *args, **kwargs): | ||
| super().__init__(uid, *args, **kwargs) | ||
| self.tokenizer = GeneralTokenizer() | ||
| self.separator = "</think>" # Tag that separates reasoning from content. | ||
|
|
||
| def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> ReasoningRequest: | ||
| max_total_tokens = options.max_total_output_tokens | ||
| if max_total_tokens is None: | ||
| max_total_tokens = options.max_tokens | ||
| max_content_tokens = options.max_tokens | ||
|
|
||
| # Replace max_tokens in raw request with the max total tokens. | ||
| options.max_tokens = max_total_tokens | ||
| request = super().translate_text_prompt(prompt, options) | ||
| return ReasoningRequest( | ||
| request=request, | ||
| max_content_tokens=max_content_tokens, | ||
| max_total_tokens=max_total_tokens, | ||
| ) | ||
|
|
||
| def evaluate(self, request: ReasoningRequest) -> Any: | ||
| return super().evaluate(request.request) # type: ignore | ||
|
|
||
| def translate_response(self, request: ReasoningRequest, response: Any) -> SUTResponse: | ||
| text = super().translate_response(request.request, response).text # type: ignore | ||
|
wpietri marked this conversation as resolved.
|
||
|
|
||
| think_close = text.find(self.separator) | ||
| if think_close == -1: | ||
| # no closing tag: everything is thinking text | ||
| return SUTResponse(text="") | ||
|
|
||
| reasoning = text[: think_close + len(self.separator)].strip() | ||
| content = text[think_close + len(self.separator) :].strip() | ||
| self.warn_edge_cases(content, reasoning, request) | ||
|
|
||
| # Truncate content | ||
| if request.max_content_tokens is not None: | ||
| content = self.tokenizer.truncate(content, request.max_content_tokens) | ||
| return SUTResponse(text=content) | ||
|
|
||
| def warn_edge_cases(self, content, reasoning, request): | ||
| if request.max_total_tokens is None: | ||
| return | ||
| reasoning_tokens = self.tokenizer.count_tokens(reasoning) | ||
| content_tokens = self.tokenizer.count_tokens(content) | ||
| reasoning_budget = request.request.max_tokens - request.max_content_tokens | ||
|
|
||
| if reasoning_tokens >= reasoning_budget and content_tokens + reasoning_tokens >= request.max_total_tokens: | ||
| logger.warning( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't do anything with warnings in benchmark runs right now. Should we?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so. Kurt wanted this logged so that we can do some manual analysis later on. |
||
| f"SUT {self.uid} reasoning likely ate into the token budget of the actual output. Consider increasing max_total_output_tokens." | ||
|
wpietri marked this conversation as resolved.
|
||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.