Skip to content

Commit dadb45e

Browse files
EYH0602 and Copilot authored
feat: add default option using transformers (#67)
* add transformers generation as default * remove None option for router * remove vllm option for ease of dependency * Update src/tfbench/lm/_hf.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/tfbench/lm/_hf.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * remove unnecessary imports --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 6733da2 commit dadb45e

File tree

8 files changed

+89
-1729
lines changed

8 files changed

+89
-1729
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ dependencies = [
3434
"tenacity>=9.1.2",
3535
"tiktoken==0.7.0",
3636
"tqdm>=4.66.2",
37+
"transformers[torch]>=4.55.4",
3738
"tree-sitter==0.22.3",
3839
"tree-sitter-haskell==0.21.0",
3940
"types-deprecated>=1.2.15.20250304",
4041
"types-requests>=2.31.0",
41-
"vllm>=0.10.1.1",
4242
]
4343

4444
[build-system]

src/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def main(
1818
def _run(pure: bool):
1919
results: list[EvalResult] = []
2020
split = "pure" if pure else "base"
21+
result_dir = abspath(pjoin("results", model, split))
2122
for i in range(n_repeats):
22-
result_dir = abspath(pjoin("results", model, split))
2323
os.makedirs(result_dir, exist_ok=True)
2424
result_file = pjoin(result_dir, f"run-{i}.jsonl")
2525
r = run_one_model(

src/tfbench/env.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from dotenv import dotenv_values
1+
from dotenv import load_dotenv
22

3-
ENV = dotenv_values(".env")
4-
5-
assert ENV, "No .env file found! Please create one with the required variables."
3+
load_dotenv(override=True) # override existing env vars with those in .env

src/tfbench/experiment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def run_one_model(
3131
EvalResult: evaluation result including accuracy
3232
"""
3333
client = router(model, pure, effort)
34-
assert client is not None, f"Failed to create client for {model}."
3534

3635
tasks = load_tfb_from_hf("pure" if pure else "base")
3736
gen_results: list[LMAnswer | None] = []

src/tfbench/lm/_hf.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from transformers import AutoModelForCausalLM, AutoTokenizer
2+
3+
from ._types import LM, LMAnswer
4+
5+
6+
def extract_thinking_content(output: str) -> tuple[str, str | None]:
7+
"""Extract the thinking content and the final answer from the model output.
8+
based on <think> and </think> tags.
9+
10+
Args:
11+
output (str): The model output.
12+
Returns:
13+
tuple[str, str | None]: The thinking content and the final answer.
14+
"""
15+
if "<think>" in output and "</think>" in output:
16+
thinking_content = output.split("<think>")[1].split("</think>")[0].strip()
17+
content = output.split("</think>")[-1].strip()
18+
return content, thinking_content
19+
20+
return output, None
21+
22+
23+
class HFChat(LM):
    """Local chat model backed by Hugging Face ``transformers``.

    Loads tokenizer and weights in-process and generates completions
    locally, using the model's own chat template.
    """

    def __init__(self, model_name: str, pure: bool = False):
        super().__init__(model_name=model_name, pure=pure)
        # Load the tokenizer and weights once; dtype selection and device
        # placement are delegated to transformers via "auto".
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto"
        )

    def _gen(self, prompt: str) -> LMAnswer:
        """Generate an answer for *prompt* via the model's chat template."""
        chat = [
            {"role": "system", "content": self.instruction},
            {"role": "user", "content": prompt},
        ]
        rendered = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer([rendered], return_tensors="pt").to(self.model.device)

        # Run text completion, then strip the prompt tokens from the front of
        # the generated sequence so only the new completion is decoded.
        full_ids = self.model.generate(**inputs, max_new_tokens=32768)
        prompt_len = len(inputs.input_ids[0])
        completion_ids = full_ids[0][prompt_len:].tolist()
        decoded = self.tokenizer.decode(
            completion_ids, skip_special_tokens=True
        ).strip("\n")

        answer, reasoning = extract_thinking_content(decoded)
        return LMAnswer(answer=answer, reasoning_steps=reasoning)

src/tfbench/lm/_vllm.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

src/tfbench/lm/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ._google import GeminiChat, GeminiReasoning, GEMINI_MODELS, GEMINI_TTC_MODELS
1717
from ._anthropic import ClaudeChat, ClaudeReasoning, CLAUDE_MODELS, CLAUDE_TTC_MODELS
1818
from ._ollama import OllamaChat, OLLAMA_TTC_MODELS
19+
from ._hf import HFChat
1920

2021
from ._google import GeminiReasoningEffort
2122
from ._types import ReasoningEffort
@@ -47,7 +48,7 @@ def router(
4748
model_name: str,
4849
pure: bool,
4950
effort: str | None = None,
50-
) -> LM | None:
51+
) -> LM:
5152
"""Route the model name to the appropriate LM class."""
5253
if model_name in OAI_MODELS:
5354
return OpenAIChatCompletion(model_name=model_name, pure=pure)
@@ -89,7 +90,7 @@ def router(
8990
if model_name in OLLAMA_TTC_MODELS:
9091
return OllamaChat(model_name=model_name, pure=pure)
9192

92-
return None
93+
return HFChat(model_name=model_name, pure=pure)
9394

9495

9596
def extract_response(response: ResultE[LMAnswer]) -> str:

0 commit comments

Comments
 (0)