11from datetime import datetime
22from enum import Enum
3- from typing import List , Optional , Union
3+ from typing import List , Optional , Any , Union
44
55from fastapi import Body
6+ from langchain_core .messages import SystemMessage , HumanMessage , AIMessage
67from pydantic import BaseModel
78from sqlalchemy import Column , Integer , Text , BigInteger , DateTime , Identity , Boolean
89from sqlalchemy import Enum as SQLAlchemyEnum
@@ -230,6 +231,7 @@ class AiModelQuestion(BaseModel):
230231 regenerate_record_id : Optional [int ] = None
231232
232233 def sql_sys_question (self , db_type : Union [str , DB ], enable_query_limit : bool = True ):
234+ templates : dict [str , str ] = {}
233235 _sql_template = get_sql_example_template (db_type )
234236 _base_template = get_sql_template ()
235237 _process_check = _sql_template .get ('process_check' ) if _sql_template .get ('process_check' ) else _base_template [
@@ -245,22 +247,37 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
245247 'example_answer_2' ]
246248 _example_answer_3 = _sql_template ['example_answer_3_with_limit' ] if enable_query_limit else _sql_template [
247249 'example_answer_3' ]
248- return _base_template ['system' ].format (engine = self .engine , schema = self .db_schema , question = self .question ,
249- lang = self .lang , terminologies = self .terminologies ,
250- data_training = self .data_training , custom_prompt = self .custom_prompt ,
251- process_check = _process_check ,
252- base_sql_rules = _base_sql_rules ,
253- basic_sql_examples = _sql_examples ,
254- example_engine = _example_engine ,
255- example_answer_1 = _example_answer_1 ,
256- example_answer_2 = _example_answer_2 ,
257- example_answer_3 = _example_answer_3 )
250+
251+ templates ['system' ] = _base_template ['system' ].format (process_check = _process_check )
252+ templates ['rules' ] = _base_template ['generate_rules' ].format (lang = self .lang ,
253+ base_sql_rules = _base_sql_rules ,
254+ basic_sql_examples = _sql_examples ,
255+ example_engine = _example_engine ,
256+ example_answer_1 = _example_answer_1 ,
257+ example_answer_2 = _example_answer_2 ,
258+ example_answer_3 = _example_answer_3 )
259+ templates ['schema' ] = _base_template ['generate_basic_info' ].format (engine = self .engine , schema = self .db_schema )
260+
261+ if self .terminologies :
262+ templates ['terminologies' ] = _base_template ['generate_terminologies_info' ].format (
263+ terminologies = self .terminologies )
264+
265+ if self .data_training :
266+ templates ['data_training' ] = _base_template ['generate_data_training_info' ].format (
267+ data_training = self .data_training )
268+
269+ if self .custom_prompt :
270+ templates ['custom_prompt' ] = _base_template ['generate_custom_prompt_info' ].format (
271+ custom_prompt = self .custom_prompt )
272+
273+ return templates
258274
259275 def sql_user_question (self , current_time : str , change_title : bool ):
260276 _question = self .question
261277 if self .regenerate_record_id :
262278 _question = get_sql_template ()['regenerate_hint' ] + self .question
263- return get_sql_template ()['user' ].format (engine = self .engine , schema = self .db_schema , question = _question ,
279+ return get_sql_template ()['user' ].format (lang = self .lang , engine = self .engine , schema = self .db_schema ,
280+ question = _question ,
264281 rule = self .rule , current_time = current_time , error_msg = self .error_msg ,
265282 change_title = change_title )
266283
@@ -358,3 +375,30 @@ class McpAssistant(BaseModel):
358375 url : str = Body (description = '第三方数据接口' )
359376 authorization : str = Body (description = '第三方接口凭证' )
360377 stream : Optional [bool ] = Body (description = '是否流式输出,默认为true开启, 关闭false则返回JSON对象' , default = True )
378+
379+
380+ class SystemPromptMessage (SystemMessage ):
381+ sqlbot_system : bool = True
382+
383+ def __init__ (
384+ self , content : Union [str , list [Union [str , dict ]]], ** kwargs : Any
385+ ) -> None :
386+ super ().__init__ (content = content , ** kwargs )
387+
388+
389+ class HumanPromptMessage (HumanMessage ):
390+ sqlbot_system : bool = True
391+
392+ def __init__ (
393+ self , content : Union [str , list [Union [str , dict ]]], ** kwargs : Any
394+ ) -> None :
395+ super ().__init__ (content = content , ** kwargs )
396+
397+
398+ class AIPromptMessage (AIMessage ):
399+ sqlbot_system : bool = True
400+
401+ def __init__ (
402+ self , content : Union [str , list [Union [str , dict ]]], ** kwargs : Any
403+ ) -> None :
404+ super ().__init__ (content = content , ** kwargs )
0 commit comments