Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ dependencies = [
"tenacity>=9.1.2",
"tiktoken==0.7.0",
"tqdm>=4.66.2",
"transformers[torch]>=4.55.4",
"tree-sitter==0.22.3",
"tree-sitter-haskell==0.21.0",
"types-deprecated>=1.2.15.20250304",
"types-requests>=2.31.0",
"vllm>=0.10.1.1",
]

[build-system]
Expand Down
2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def main(
def _run(pure: bool):
results: list[EvalResult] = []
split = "pure" if pure else "base"
result_dir = abspath(pjoin("results", model, split))
for i in range(n_repeats):
result_dir = abspath(pjoin("results", model, split))
os.makedirs(result_dir, exist_ok=True)
result_file = pjoin(result_dir, f"run-{i}.jsonl")
r = run_one_model(
Expand Down
6 changes: 2 additions & 4 deletions src/tfbench/env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from dotenv import dotenv_values
from dotenv import load_dotenv

ENV = dotenv_values(".env")

assert ENV, "No .env file found! Please create one with the required variables."
load_dotenv(override=True) # override existing env vars with those in .env
1 change: 0 additions & 1 deletion src/tfbench/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def run_one_model(
EvalResult: evaluation result including accuracy
"""
client = router(model, pure, effort)
assert client is not None, f"Failed to create client for {model}."

tasks = load_tfb_from_hf("pure" if pure else "base")
gen_results: list[LMAnswer | None] = []
Expand Down
53 changes: 53 additions & 0 deletions src/tfbench/lm/_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from ._types import LM, LMAnswer


def extract_thinking_content(output: str) -> tuple[str, str | None]:
"""Extract the thinking content and the final answer from the model output.
based on <think> and </think> tags.

Args:
output (str): The model output.
Returns:
tuple[str, str | None]: The thinking content and the final answer.
"""
if "<think>" in output and "</think>" in output:
thinking_content = output.split("<think>")[1].split("</think>")[0].strip()
content = output.split("</think>")[-1].strip()
return content, thinking_content

return output, None


class HFChat(LM):
    """LM backend that runs a Hugging Face causal model locally.

    Loads the tokenizer and weights once at construction; generation is
    synchronous and single-prompt.
    """

    def __init__(self, model_name: str, pure: bool = False):
        super().__init__(model_name=model_name, pure=pure)

        # Fetch tokenizer and weights up front so _gen only does inference.
        # device_map="auto" lets accelerate place the model (GPU if present).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto"
        )

    def _gen(self, prompt: str) -> LMAnswer:
        """Generate an answer for `prompt`, separating any <think> trace.

        Args:
            prompt (str): The user-facing task prompt.

        Returns:
            LMAnswer: The final answer plus optional reasoning steps.
        """
        chat = [
            {"role": "system", "content": self.instruction},
            {"role": "user", "content": prompt},
        ]
        rendered = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = self.tokenizer([rendered], return_tensors="pt").to(self.model.device)

        # Generate, then keep only the newly produced tokens (drop the prompt).
        full_ids = self.model.generate(**inputs, max_new_tokens=32768)
        prompt_len = len(inputs.input_ids[0])
        new_ids = full_ids[0][prompt_len:].tolist()
        decoded = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip("\n")

        answer, reasoning = extract_thinking_content(decoded)
        return LMAnswer(answer=answer, reasoning_steps=reasoning)
70 changes: 0 additions & 70 deletions src/tfbench/lm/_vllm.py

This file was deleted.

5 changes: 3 additions & 2 deletions src/tfbench/lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ._google import GeminiChat, GeminiReasoning, GEMINI_MODELS, GEMINI_TTC_MODELS
from ._anthropic import ClaudeChat, ClaudeReasoning, CLAUDE_MODELS, CLAUDE_TTC_MODELS
from ._ollama import OllamaChat, OLLAMA_TTC_MODELS
from ._hf import HFChat

from ._google import GeminiReasoningEffort
from ._types import ReasoningEffort
Expand Down Expand Up @@ -47,7 +48,7 @@ def router(
model_name: str,
pure: bool,
effort: str | None = None,
) -> LM | None:
) -> LM:
"""Route the model name to the appropriate LM class."""
if model_name in OAI_MODELS:
return OpenAIChatCompletion(model_name=model_name, pure=pure)
Expand Down Expand Up @@ -89,7 +90,7 @@ def router(
if model_name in OLLAMA_TTC_MODELS:
return OllamaChat(model_name=model_name, pure=pure)

return None
return HFChat(model_name=model_name, pure=pure)


def extract_response(response: ResultE[LMAnswer]) -> str:
Expand Down
Loading