
Commit a116450

add transformers generation as default
1 parent 6733da2 commit a116450

6 files changed: 91 additions & 9 deletions


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ dependencies = [
     "tenacity>=9.1.2",
     "tiktoken==0.7.0",
     "tqdm>=4.66.2",
+    "transformers[torch]>=4.55.4",
     "tree-sitter==0.22.3",
     "tree-sitter-haskell==0.21.0",
     "types-deprecated>=1.2.15.20250304",

src/tfbench/env.py

Lines changed: 2 additions & 4 deletions
@@ -1,5 +1,3 @@
-from dotenv import dotenv_values
+from dotenv import load_dotenv
 
-ENV = dotenv_values(".env")
-
-assert ENV, "No .env file found! Please create one with the required variables."
+load_dotenv(override=True)  # override existing env vars with those in .env
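The switch from dotenv_values to load_dotenv changes the contract: variables now land in os.environ rather than a module-level ENV dict, values from .env override anything already set in the shell, and a missing .env is no longer a hard error. A minimal sketch of the new access pattern (the variable name is illustrative):

import os
from dotenv import load_dotenv

# .env values win over pre-existing shell variables because of override=True.
load_dotenv(override=True)

# Consumers read the process environment directly and supply their own
# defaults, instead of hitting the old assert when .env is absent.
host = os.getenv("VLLM_HOST", "localhost")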

src/tfbench/lm/_hf.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from ._types import LM, LMAnswer, NoneResponseError
+
+
+def extract_thinking_content(output: str) -> tuple[str, str | None]:
+    """Extract the thinking content and the final answer from the model output,
+    based on <think> and </think> tags.
+
+    Args:
+        output (str): The model output.
+    Returns:
+        tuple[str, str | None]: The final answer and the thinking content.
+    """
+    if "<think>" in output and "</think>" in output:
+        thinking_content = output.split("<think>")[1].split("</think>")[0].strip()
+        content = output.split("</think>")[-1].strip()
+        return content, thinking_content
+
+    return output, None
+
+
+class HFChat(LM):
+
+    def __init__(self, model_name: str, pure: bool = False):
+        super().__init__(model_name=model_name, pure=pure)
+
+        # load the tokenizer and the model
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype="auto", device_map="auto"
+        )
+
+    def _gen(self, prompt: str) -> LMAnswer:
+        messages = [
+            {"role": "user", "content": prompt},
+            {"role": "system", "content": self.instruction},
+        ]
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+
+        # conduct text completion
+        generated_ids = self.model.generate(**model_inputs, max_new_tokens=32768)
+        output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
+        output = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
+
+        if output is None:
+            raise NoneResponseError(self.model_name)
+
+        content, thinking_content = extract_thinking_content(output)
+        return LMAnswer(answer=content, reasoning_steps=thinking_content)
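A short sketch of how the new pieces fit together. The tag parser can be exercised without loading any weights; the generation path goes through the private _gen here only because the public entry point of LM is not shown in this diff, and the model name is illustrative:

from tfbench.lm._hf import HFChat, extract_thinking_content

# Tag parsing is model-independent: the answer comes first, the thinking second.
content, thinking = extract_thinking_content(
    "<think>check the type signature</think>map :: (a -> b) -> [a] -> [b]"
)
assert content == "map :: (a -> b) -> [a] -> [b]"
assert thinking == "check the type signature"

# Loads the full checkpoint onto whatever devices are available.
lm = HFChat(model_name="Qwen/Qwen3-0.6B")  # illustrative model name
answer = lm._gen("What is the type of `map` in Haskell?")
print(answer.answer)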

src/tfbench/lm/_vllm.py

Lines changed: 4 additions & 4 deletions
@@ -1,9 +1,9 @@
+import os
 from vllm import LLM
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 
 from openai import OpenAI
 
-from ..env import ENV
 from ._types import LM, LMAnswer, NoneResponseError
 
 
@@ -32,9 +32,9 @@ class VLLMOpenAIChatCompletion(LM):
     def __init__(self, model_name: str, pure: bool = False):
         super().__init__(model_name=model_name, pure=pure)
 
-        api_key = ENV.get("VLLM_API_KEY", "")
-        host = ENV.get("VLLM_HOST", "localhost")
-        port = ENV.get("VLLM_PORT", "8000")
+        api_key = os.getenv("VLLM_API_KEY", "")
+        host = os.getenv("VLLM_HOST", "localhost")
+        port = os.getenv("VLLM_PORT", "8000")
 
         url = f"http://{host}:{port}/v1"
         self.client = OpenAI(
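Since the client now reads the process environment directly, pointing it at a non-default server is just a matter of setting variables before construction, whether in .env or in the shell. A sketch (hostname and model name are illustrative):

import os
from tfbench.lm._vllm import VLLMOpenAIChatCompletion

# These would normally live in .env and be injected by load_dotenv(override=True).
os.environ["VLLM_HOST"] = "gpu-box.internal"
os.environ["VLLM_PORT"] = "8001"

# The client resolves http://gpu-box.internal:8001/v1 at construction time.
lm = VLLMOpenAIChatCompletion(model_name="meta-llama/Llama-3.1-8B-Instruct")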

src/tfbench/lm/utils.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 from ._google import GeminiChat, GeminiReasoning, GEMINI_MODELS, GEMINI_TTC_MODELS
 from ._anthropic import ClaudeChat, ClaudeReasoning, CLAUDE_MODELS, CLAUDE_TTC_MODELS
 from ._ollama import OllamaChat, OLLAMA_TTC_MODELS
+from ._hf import HFChat
 
 from ._google import GeminiReasoningEffort
 from ._types import ReasoningEffort
@@ -89,7 +90,7 @@ def router(
     if model_name in OLLAMA_TTC_MODELS:
         return OllamaChat(model_name=model_name, pure=pure)
 
-    return None
+    return HFChat(model_name=model_name, pure=pure)
 
 
 def extract_response(response: ResultE[LMAnswer]) -> str:
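The router's fallthrough now means any model name that misses the OpenAI, Gemini, Claude, and Ollama tables is handed to local transformers generation instead of yielding None. A sketch of the new default (the model name is illustrative and assumed absent from every provider table):

from tfbench.lm.utils import router

# Unrecognized names no longer return None; they load locally via HFChat.
lm = router(model_name="Qwen/Qwen3-0.6B")
print(type(lm).__name__)  # -> "HFChat"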

uv.lock

Lines changed: 26 additions & 0 deletions
Generated lockfile; diff not rendered.
