Commit 16b3573

Merge pull request #40 from open-sciencelab/think-model
filter think tag when using reasoning models
2 parents c883cf2 + c1a4359 commit 16b3573

2 files changed: 67 additions & 34 deletions

Lines changed: 55 additions & 34 deletions
@@ -1,32 +1,45 @@
 import math
+import re
 from dataclasses import dataclass, field
-from typing import List, Dict, Optional
+from typing import Dict, List, Optional
+
 import openai
-from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APITimeoutError
+from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
 from tenacity import (
     retry,
+    retry_if_exception_type,
     stop_after_attempt,
     wait_exponential,
-    retry_if_exception_type,
 )
 
-from graphgen.models.llm.topk_token_model import TopkTokenModel, Token
-from graphgen.models.llm.tokenizer import Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
+from graphgen.models.llm.tokenizer import Tokenizer
+from graphgen.models.llm.topk_token_model import Token, TopkTokenModel
+
 
 def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
     token_logprobs = response.choices[0].logprobs.content
     tokens = []
     for token_prob in token_logprobs:
         prob = math.exp(token_prob.logprob)
         candidate_tokens = [
-            Token(t.token, math.exp(t.logprob))
-            for t in token_prob.top_logprobs
+            Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
         ]
         token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
         tokens.append(token)
     return tokens
 
+
+def filter_think_tags(text: str) -> str:
+    """
+    Remove <think> tags from the text.
+    If the text contains <think> and </think>, it removes everything between them and the tags themselves.
+    """
+    think_pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
+    filtered_text = think_pattern.sub("", text).strip()
+    return filtered_text if filtered_text else text.strip()
+
+
 @dataclass
 class OpenAIModel(TopkTokenModel):
     model_name: str = "gpt-4o-mini"
@@ -42,12 +55,11 @@ class OpenAIModel(TopkTokenModel):
     rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
     tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
 
-
     def __post_init__(self):
         assert self.api_key is not None, "Please provide api key to access openai api."
-        if self.api_key == "":
-            self.api_key = "none"
-        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
+        self.client = AsyncOpenAI(
+            api_key=self.api_key or "dummy", base_url=self.base_url
+        )
 
     def _pre_generate(self, text: str, history: List[str]) -> Dict:
         kwargs = {
@@ -69,16 +81,19 @@ def _pre_generate(self, text: str, history: List[str]) -> Dict:
             assert len(history) % 2 == 0, "History should have even number of elements."
             messages = history + messages
 
-        kwargs['messages']= messages
+        kwargs["messages"] = messages
         return kwargs
 
-
     @retry(
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
+        retry=retry_if_exception_type(
+            (RateLimitError, APIConnectionError, APITimeoutError)
+        ),
     )
-    async def generate_topk_per_token(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None
+    ) -> List[Token]:
         kwargs = self._pre_generate(text, history)
         if self.topk_per_token > 0:
             kwargs["logprobs"] = True
@@ -87,9 +102,8 @@ async def generate_topk_per_token(self, text: str, history: Optional[List[str]]
         # Limit max_tokens to 1 to avoid long completions
         kwargs["max_tokens"] = 1
 
-        completion = await self.client.chat.completions.create( # pylint: disable=E1125
-            model=self.model_name,
-            **kwargs
+        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
+            model=self.model_name, **kwargs
         )
 
         tokens = get_top_response_tokens(completion)
@@ -99,32 +113,39 @@ async def generate_topk_per_token(self, text: str, history: Optional[List[str]]
     @retry(
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
+        retry=retry_if_exception_type(
+            (RateLimitError, APIConnectionError, APITimeoutError)
+        ),
    )
-    async def generate_answer(self, text: str, history: Optional[List[str]] = None, temperature: int = 0) -> str:
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, temperature: int = 0
+    ) -> str:
         kwargs = self._pre_generate(text, history)
         kwargs["temperature"] = temperature
 
         prompt_tokens = 0
-        for message in kwargs['messages']:
-            prompt_tokens += len(Tokenizer().encode_string(message['content']))
-        estimated_tokens = prompt_tokens + kwargs['max_tokens']
+        for message in kwargs["messages"]:
+            prompt_tokens += len(Tokenizer().encode_string(message["content"]))
+        estimated_tokens = prompt_tokens + kwargs["max_tokens"]
 
         if self.request_limit:
             await self.rpm.wait(silent=True)
             await self.tpm.wait(estimated_tokens, silent=True)
 
-        completion = await self.client.chat.completions.create( # pylint: disable=E1125
-            model=self.model_name,
-            **kwargs
+        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
+            model=self.model_name, **kwargs
        )
         if hasattr(completion, "usage"):
-            self.token_usage.append({
-                "prompt_tokens": completion.usage.prompt_tokens,
-                "completion_tokens": completion.usage.completion_tokens,
-                "total_tokens": completion.usage.total_tokens,
-            })
-        return completion.choices[0].message.content
-
-    async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
+            self.token_usage.append(
+                {
+                    "prompt_tokens": completion.usage.prompt_tokens,
+                    "completion_tokens": completion.usage.completion_tokens,
+                    "total_tokens": completion.usage.total_tokens,
+                }
+            )
+        return filter_think_tags(completion.choices[0].message.content)
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None
+    ) -> List[Token]:
        raise NotImplementedError
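
For context, here is a minimal, self-contained sketch of the new filtering behavior; the sample replies are hypothetical, not taken from the repository:

import re


def filter_think_tags(text: str) -> str:
    # Same logic as the helper added in this commit: drop <think>...</think>
    # spans non-greedily (.*?), across newlines (re.DOTALL).
    think_pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
    filtered_text = think_pattern.sub("", text).strip()
    # Fall back to the original text if the reply was nothing but a think block.
    return filtered_text if filtered_text else text.strip()


reply = "<think>2 + 2 = 4; answer briefly.</think>The answer is 4."
print(filter_think_tags(reply))  # -> The answer is 4.
print(filter_think_tags("<think>only thoughts</think>"))  # -> <think>only thoughts</think>

The non-greedy pattern matters when a reply contains several think blocks: each block is removed on its own, rather than one greedy match swallowing the answer text between them.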

pyproject.toml

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+[tool.black]
+line-length = 88  # black's default is 88
+include = '\.pyi?$'
+
+[tool.isort]
+profile = "black"  # one-stop compatibility with black
+line_length = 88  # keep consistent with black
+multi_line_output = 3  # black's preferred parenthesized wrapping style
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+ensure_newline_before_comments = true
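
A note on these settings: isort's profile = "black" already implies multi_line_output = 3, include_trailing_comma = true, force_grid_wrap = 0, use_parentheses = true, ensure_newline_before_comments = true, and line_length = 88, so the explicit keys simply pin that behavior down. Mode 3 is isort's "vertical hanging indent" style; an import longer than 88 columns (the module path below is made up for illustration) gets wrapped as:

from graphgen.models.llm.some_hypothetical_module import (
    FirstSymbol,
    SecondSymbol,
    ThirdSymbol,
)

This is the same shape black gives long parenthesized lines, so running isort and then black should leave imports stable.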
