Skip to content

Commit 5ef72f9

Browse files
committed
add transformers generation as default
1 parent dadb45e commit 5ef72f9

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/tfbench/lm/_hf.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from transformers import AutoModelForCausalLM, AutoTokenizer
22

3-
from ._types import LM, LMAnswer
3+
from ._types import LM, LMAnswer, NoneResponseError
44

55

66
def extract_thinking_content(output: str) -> tuple[str, str | None]:
@@ -33,8 +33,8 @@ def __init__(self, model_name: str, pure: bool = False):
3333

3434
def _gen(self, prompt: str) -> LMAnswer:
3535
messages = [
36-
{"role": "system", "content": self.instruction},
3736
{"role": "user", "content": prompt},
37+
{"role": "system", "content": self.instruction},
3838
]
3939
text = self.tokenizer.apply_chat_template(
4040
messages,
@@ -49,5 +49,8 @@ def _gen(self, prompt: str) -> LMAnswer:
4949
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
5050
output = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
5151

52+
if output is None:
53+
raise NoneResponseError(self.model_name)
54+
5255
content, thinking_content = extract_thinking_content(output)
5356
return LMAnswer(answer=content, reasoning_steps=thinking_content)

0 commit comments

Comments
 (0)