refactor: update ZhipuChatModel to use BaseChatOpenAI and improve token counting #4291
Conversation
refactor: update ZhipuChatModel to use BaseChatOpenAI and improve token counting --bug=1061305 --user=刘瑞斌 [Application] After enabling tools in AI chat, some models (Zhipu) do not count tokens https://www.tapd.cn/62980211/s/1791683
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected; please follow our release note process to remove it.

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes-sigs/prow repository.
[APPROVALNOTIFIER] This PR is NOT APPROVED. This pull-request has been approved by: (none yet). The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files. Approvers can indicate their approval by writing `/approve` in a comment.
```python
        return super().get_num_tokens(text)
    except Exception as e:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
```
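The lines above show the fallback being reviewed: prefer the OpenAI-compatible token counter, and only fall back to a local tokenizer when it raises. A minimal standalone sketch of the same pattern (where `FakeTokenizer` and `count_tokens` are hypothetical stand-ins for `TokenizerManage.get_tokenizer()` and the model's `super().get_num_tokens(text)` call):

```python
class FakeTokenizer:
    """Stand-in for TokenizerManage.get_tokenizer(); splits on whitespace."""

    def encode(self, text):
        # Crude whitespace tokenization, purely for illustration.
        return text.split()


def count_tokens(text, primary_counter=None):
    """Try the primary counter first; fall back to a local tokenizer."""
    try:
        if primary_counter is None:
            raise RuntimeError("no OpenAI-compatible counter available")
        return primary_counter(text)
    except Exception:
        tokenizer = FakeTokenizer()
        return len(tokenizer.encode(text))


print(count_tokens("hello world from zhipu"))  # fallback path -> 4
```

The real model inherits the primary counter from `BaseChatOpenAI`; the point of the pattern is that any failure there degrades to a local count instead of breaking the request.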
No significant changes have been identified in the provided code snippet, but here are some minor improvements and clarifications:

- Remove the `self.optional_params` assignment, as it is not used anywhere in the class.
- Adjust the token-counting methods so both classes use a consistent approach.

Here's a revised version of the code with these corrections applied:
```python
# Import necessary libraries and classes
from typing import List

from langchain_core.messages import BaseMessage
from langchain_openai.chat_models.base import BaseChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
# MaxKBBaseModel is a MaxKB project class; its import is omitted here.


class ZhipuChatModel(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
    def is_cache_model():
        pass

    def __init__(self, *args, max_length=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _token_count(self, text: str) -> int:
        # Use the default method provided by BaseChatOpenAI if available
        try:
            return super().get_num_tokens(text)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            tokens = tokenizer.encode(text)
            return min(len(tokens), self.max_length) if self.max_length else len(tokens)

    def generate_responses(self, messages: List[BaseMessage]) -> List[str]:
        total_tokens = 0
        responses = []
        for message in messages:
            token_count = self._token_count(message.content)
            total_tokens += token_count
            # Perform your logic here based on whether you want to split
            # or concatenate input data
        # Return all generated responses or handle further processing
        return responses
```

Key Changes:
- **Initialization:** Added a `max_length` parameter during initialization and stored it in an instance variable.
- **Token Count Calculation:** Modified the `_token_count` method to first call the superclass's method (`super().get_num_tokens(text)`), which should be more reliable than manually encoding text via `TokenizerManage` (since it may include additional optimizations).
- **Replaced Custom Method:** Removed the custom `custom_get_token_ids` method, since we can now directly get a list of IDs using the same technique.
This change ensures that the token count reflects the actual input length rather than potentially over-counting due to the use of multiple encoders. It should also make future modifications easier without changing much of the core logic.
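One detail of the revised `_token_count` worth isolating is the `max_length` clamp on the fallback path: when a limit is configured, the reported count is capped at it. A small sketch (the helper name `clamped_token_count` is hypothetical):

```python
def clamped_token_count(tokens, max_length=None):
    # Return the token count, capped at max_length when one is configured.
    return min(len(tokens), max_length) if max_length is not None else len(tokens)


print(clamped_token_count(list(range(10))))                # 10
print(clamped_token_count(list(range(10)), max_length=4))  # 4
```

Note that the original snippet wrote `if max_length` instead of `if self.max_length`, which would raise a `NameError` inside the method; the instance attribute must be used.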