11# coding=utf-8
22import base64
3- import json
43from concurrent .futures import ThreadPoolExecutor
5- from requests .exceptions import ConnectTimeout , ReadTimeout
64from typing import Dict , Optional , Any , Iterator , cast , Union , Sequence , Callable , Mapping
75
86from langchain_core .language_models import LanguageModelInput
97from langchain_core .messages import BaseMessage , get_buffer_string , BaseMessageChunk , HumanMessageChunk , AIMessageChunk , \
108 SystemMessageChunk , FunctionMessageChunk , ChatMessageChunk
11- from langchain_core .messages .ai import UsageMetadata , AIMessage
9+ from langchain_core .messages .ai import UsageMetadata
1210from langchain_core .messages .tool import tool_call_chunk , ToolMessageChunk
1311from langchain_core .outputs import ChatGenerationChunk
1412from langchain_core .runnables import RunnableConfig , ensure_config
1513from langchain_core .tools import BaseTool
1614from langchain_openai import ChatOpenAI
1715from langchain_openai .chat_models .base import _create_usage_metadata
16+ from requests .exceptions import ReadTimeout
1817
1918from common .config .tokenizer_manage_config import TokenizerManage
2019from common .utils .logger import maxkb_logger
2120
21+
def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer returned by ``TokenizerManage.get_tokenizer()``.

    :param text: the string to tokenize.
    :return: whatever the tokenizer's ``encode`` call returns for ``text``.
    """
    return TokenizerManage.get_tokenizer().encode(text)
2525
26+
2627def _convert_delta_to_message_chunk (
27- _dict : Mapping [str , Any ], default_class : type [BaseMessageChunk ]
28+ _dict : Mapping [str , Any ], default_class : type [BaseMessageChunk ]
2829) -> BaseMessageChunk :
2930 """Convert to a LangChain message chunk."""
3031 id_ = _dict .get ("id" )
@@ -80,6 +81,7 @@ def _convert_delta_to_message_chunk(
8081 return ChatMessageChunk (content = content , role = role , id = id_ )
8182 return default_class (content = content , id = id_ ) # type: ignore[call-arg]#
8283
84+
8385class BaseChatOpenAI (ChatOpenAI ):
8486 usage_metadata : dict = {}
8587 custom_get_token_ids = custom_get_token_ids
@@ -219,64 +221,13 @@ def invoke(
219221 'token_usage' ] if 'token_usage' in chat_result .response_metadata else chat_result .usage_metadata
220222 return chat_result
221223
222- def _get_request_payload (
223- self ,
224- input_ : LanguageModelInput ,
225- * ,
226- stop : list [str ] | None = None ,
227- ** kwargs : Any ,
228- ) -> dict :
229- # Get original messages to preserve reasoning_content before base conversion
230- messages = self ._convert_input (input_ ).to_messages ()
231- # Store reasoning_content for AIMessages with tool_calls
232- # According to DeepSeek API docs, reasoning_content is REQUIRED when tool_calls
233- # are present during the tool invocation process (within same question/turn).
234- # See: https://api-docs.deepseek.com/guides/thinking_mode#tool-calls
235- reasoning_content_map = {}
236- for i , msg in enumerate (messages ):
237- if (
238- isinstance (msg , AIMessage )
239- and (msg .tool_calls or msg .invalid_tool_calls )
240- and (reasoning := msg .additional_kwargs .get ("reasoning_content" ))
241- ):
242- reasoning_content_map [i ] = reasoning
243-
244- payload = super ()._get_request_payload (input_ , stop = stop , ** kwargs )
245-
246- # Restore reasoning_content for assistant messages with tool_calls
247- # This is required by DeepSeek API - missing it causes 400 error
248- if "messages" in payload and reasoning_content_map :
249- for i , message in enumerate (payload ["messages" ]):
250- if (
251- i in reasoning_content_map
252- and message .get ("role" ) == "assistant"
253- and message .get ("tool_calls" )
254- ):
255- message ["reasoning_content" ] = reasoning_content_map [i ]
256-
257- # Apply DeepSeek-specific message formatting
258- for message in payload ["messages" ]:
259- if message ["role" ] == "tool" and isinstance (message ["content" ], list ):
260- message ["content" ] = json .dumps (message ["content" ])
261- elif message ["role" ] == "assistant" and isinstance (
262- message ["content" ], list
263- ):
264- # DeepSeek API expects assistant content to be a string, not a list.
265- # Extract text blocks and join them, or use empty string if none exist.
266- text_parts = [
267- block .get ("text" , "" )
268- for block in message ["content" ]
269- if isinstance (block , dict ) and block .get ("type" ) == "text"
270- ]
271- message ["content" ] = "" .join (text_parts ) if text_parts else ""
272- return payload
273-
274224 def upload_file_and_get_url (self , file_stream , file_name ):
275225 """上传文件并获取文件URL"""
276226 base64_video = base64 .b64encode (file_stream ).decode ("utf-8" )
277227 video_format = get_video_format (file_name )
278228 return f'data:{ video_format } ;base64,{ base64_video } '
279229
230+
280231def get_video_format (file_name ):
281232 extension = file_name .split ('.' )[- 1 ].lower ()
282233 format_map = {
@@ -285,4 +236,4 @@ def get_video_format(file_name):
285236 'mov' : 'video/mov' ,
286237 'wmv' : 'video/x-ms-wmv'
287238 }
288- return format_map .get (extension , 'video/mp4' )
239+ return format_map .get (extension , 'video/mp4' )
0 commit comments