11import math
2+ import re
23from dataclasses import dataclass , field
3- from typing import List , Dict , Optional
4+ from typing import Dict , List , Optional
5+
46import openai
5- from openai import AsyncOpenAI , RateLimitError , APIConnectionError , APITimeoutError
7+ from openai import APIConnectionError , APITimeoutError , AsyncOpenAI , RateLimitError
68from tenacity import (
79 retry ,
10+ retry_if_exception_type ,
811 stop_after_attempt ,
912 wait_exponential ,
10- retry_if_exception_type ,
1113)
1214
13- from graphgen .models .llm .topk_token_model import TopkTokenModel , Token
14- from graphgen .models .llm .tokenizer import Tokenizer
1515from graphgen .models .llm .limitter import RPM , TPM
16+ from graphgen .models .llm .tokenizer import Tokenizer
17+ from graphgen .models .llm .topk_token_model import Token , TopkTokenModel
18+
1619
def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
    """Convert the logprob payload of a chat completion into Token objects.

    Walks the per-token logprobs of the first choice; each generated token
    becomes a Token carrying its probability (exp of the logprob) together
    with the list of top-k alternative candidates and their probabilities.
    """
    result: List[Token] = []
    for entry in response.choices[0].logprobs.content:
        candidates = []
        for alt in entry.top_logprobs:
            candidates.append(Token(alt.token, math.exp(alt.logprob)))
        result.append(
            Token(entry.token, math.exp(entry.logprob), top_candidates=candidates)
        )
    return result
2931
32+
def filter_think_tags(text: str) -> str:
    """Strip ``<think>...</think>`` reasoning spans from *text*.

    Every ``<think>...</think>`` region (tags included, matching across
    newlines) is removed and the remainder is whitespace-trimmed. If
    nothing is left after removal, the trimmed original text is returned
    instead, so callers never receive an empty string for a non-empty input.
    """
    without_think = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    return without_think or text.strip()
41+
42+
3043@dataclass
3144class OpenAIModel (TopkTokenModel ):
3245 model_name : str = "gpt-4o-mini"
@@ -42,12 +55,11 @@ class OpenAIModel(TopkTokenModel):
4255 rpm : RPM = field (default_factory = lambda : RPM (rpm = 1000 ))
4356 tpm : TPM = field (default_factory = lambda : TPM (tpm = 50000 ))
4457
45-
4658 def __post_init__ (self ):
4759 assert self .api_key is not None , "Please provide api key to access openai api."
48- if self .api_key == "" :
49- self .api_key = "none"
50- self . client = AsyncOpenAI ( api_key = self . api_key , base_url = self . base_url )
60+ self .client = AsyncOpenAI (
61+ api_key = self .api_key or "dummy" , base_url = self . base_url
62+ )
5163
5264 def _pre_generate (self , text : str , history : List [str ]) -> Dict :
5365 kwargs = {
@@ -69,16 +81,19 @@ def _pre_generate(self, text: str, history: List[str]) -> Dict:
6981 assert len (history ) % 2 == 0 , "History should have even number of elements."
7082 messages = history + messages
7183
72- kwargs [' messages' ] = messages
84+ kwargs [" messages" ] = messages
7385 return kwargs
7486
75-
7687 @retry (
7788 stop = stop_after_attempt (5 ),
7889 wait = wait_exponential (multiplier = 1 , min = 4 , max = 10 ),
79- retry = retry_if_exception_type ((RateLimitError , APIConnectionError , APITimeoutError )),
90+ retry = retry_if_exception_type (
91+ (RateLimitError , APIConnectionError , APITimeoutError )
92+ ),
8093 )
81- async def generate_topk_per_token (self , text : str , history : Optional [List [str ]] = None ) -> List [Token ]:
94+ async def generate_topk_per_token (
95+ self , text : str , history : Optional [List [str ]] = None
96+ ) -> List [Token ]:
8297 kwargs = self ._pre_generate (text , history )
8398 if self .topk_per_token > 0 :
8499 kwargs ["logprobs" ] = True
@@ -87,9 +102,8 @@ async def generate_topk_per_token(self, text: str, history: Optional[List[str]]
87102 # Limit max_tokens to 1 to avoid long completions
88103 kwargs ["max_tokens" ] = 1
89104
90- completion = await self .client .chat .completions .create ( # pylint: disable=E1125
91- model = self .model_name ,
92- ** kwargs
105+ completion = await self .client .chat .completions .create ( # pylint: disable=E1125
106+ model = self .model_name , ** kwargs
93107 )
94108
95109 tokens = get_top_response_tokens (completion )
@@ -99,32 +113,39 @@ async def generate_topk_per_token(self, text: str, history: Optional[List[str]]
99113 @retry (
100114 stop = stop_after_attempt (5 ),
101115 wait = wait_exponential (multiplier = 1 , min = 4 , max = 10 ),
102- retry = retry_if_exception_type ((RateLimitError , APIConnectionError , APITimeoutError )),
116+ retry = retry_if_exception_type (
117+ (RateLimitError , APIConnectionError , APITimeoutError )
118+ ),
103119 )
104- async def generate_answer (self , text : str , history : Optional [List [str ]] = None , temperature : int = 0 ) -> str :
120+ async def generate_answer (
121+ self , text : str , history : Optional [List [str ]] = None , temperature : int = 0
122+ ) -> str :
105123 kwargs = self ._pre_generate (text , history )
106124 kwargs ["temperature" ] = temperature
107125
108126 prompt_tokens = 0
109- for message in kwargs [' messages' ]:
110- prompt_tokens += len (Tokenizer ().encode_string (message [' content' ]))
111- estimated_tokens = prompt_tokens + kwargs [' max_tokens' ]
127+ for message in kwargs [" messages" ]:
128+ prompt_tokens += len (Tokenizer ().encode_string (message [" content" ]))
129+ estimated_tokens = prompt_tokens + kwargs [" max_tokens" ]
112130
113131 if self .request_limit :
114132 await self .rpm .wait (silent = True )
115133 await self .tpm .wait (estimated_tokens , silent = True )
116134
117- completion = await self .client .chat .completions .create ( # pylint: disable=E1125
118- model = self .model_name ,
119- ** kwargs
135+ completion = await self .client .chat .completions .create ( # pylint: disable=E1125
136+ model = self .model_name , ** kwargs
120137 )
121138 if hasattr (completion , "usage" ):
122- self .token_usage .append ({
123- "prompt_tokens" : completion .usage .prompt_tokens ,
124- "completion_tokens" : completion .usage .completion_tokens ,
125- "total_tokens" : completion .usage .total_tokens ,
126- })
127- return completion .choices [0 ].message .content
128-
129- async def generate_inputs_prob (self , text : str , history : Optional [List [str ]] = None ) -> List [Token ]:
139+ self .token_usage .append (
140+ {
141+ "prompt_tokens" : completion .usage .prompt_tokens ,
142+ "completion_tokens" : completion .usage .completion_tokens ,
143+ "total_tokens" : completion .usage .total_tokens ,
144+ }
145+ )
146+ return filter_think_tags (completion .choices [0 ].message .content )
147+
    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None
    ) -> List[Token]:
        # Per-input-token probabilities are not supported by this backend;
        # declared to satisfy the TopkTokenModel interface.
        raise NotImplementedError
0 commit comments