Skip to content

Commit d0f827e

Browse files
committed
feat: improve generate sql quality
1 parent f75770f commit d0f827e

File tree

4 files changed

+164
-42
lines changed

4 files changed

+164
-42
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from datetime import datetime
22
from enum import Enum
3-
from typing import List, Optional, Union
3+
from typing import List, Optional, Any, Union
44

55
from fastapi import Body
6+
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
67
from pydantic import BaseModel
78
from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean
89
from sqlalchemy import Enum as SQLAlchemyEnum
@@ -230,6 +231,7 @@ class AiModelQuestion(BaseModel):
230231
regenerate_record_id: Optional[int] = None
231232

232233
def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
234+
templates: dict[str, str] = {}
233235
_sql_template = get_sql_example_template(db_type)
234236
_base_template = get_sql_template()
235237
_process_check = _sql_template.get('process_check') if _sql_template.get('process_check') else _base_template[
@@ -245,22 +247,37 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
245247
'example_answer_2']
246248
_example_answer_3 = _sql_template['example_answer_3_with_limit'] if enable_query_limit else _sql_template[
247249
'example_answer_3']
248-
return _base_template['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
249-
lang=self.lang, terminologies=self.terminologies,
250-
data_training=self.data_training, custom_prompt=self.custom_prompt,
251-
process_check=_process_check,
252-
base_sql_rules=_base_sql_rules,
253-
basic_sql_examples=_sql_examples,
254-
example_engine=_example_engine,
255-
example_answer_1=_example_answer_1,
256-
example_answer_2=_example_answer_2,
257-
example_answer_3=_example_answer_3)
250+
251+
templates['system'] = _base_template['system'].format(process_check=_process_check)
252+
templates['rules'] = _base_template['generate_rules'].format(lang=self.lang,
253+
base_sql_rules=_base_sql_rules,
254+
basic_sql_examples=_sql_examples,
255+
example_engine=_example_engine,
256+
example_answer_1=_example_answer_1,
257+
example_answer_2=_example_answer_2,
258+
example_answer_3=_example_answer_3)
259+
templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema)
260+
261+
if self.terminologies:
262+
templates['terminologies'] = _base_template['generate_terminologies_info'].format(
263+
terminologies=self.terminologies)
264+
265+
if self.data_training:
266+
templates['data_training'] = _base_template['generate_data_training_info'].format(
267+
data_training=self.data_training)
268+
269+
if self.custom_prompt:
270+
templates['custom_prompt'] = _base_template['generate_custom_prompt_info'].format(
271+
custom_prompt=self.custom_prompt)
272+
273+
return templates
258274

259275
def sql_user_question(self, current_time: str, change_title: bool):
260276
_question = self.question
261277
if self.regenerate_record_id:
262278
_question = get_sql_template()['regenerate_hint'] + self.question
263-
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=_question,
279+
return get_sql_template()['user'].format(lang=self.lang, engine=self.engine, schema=self.db_schema,
280+
question=_question,
264281
rule=self.rule, current_time=current_time, error_msg=self.error_msg,
265282
change_title=change_title)
266283

@@ -358,3 +375,30 @@ class McpAssistant(BaseModel):
358375
url: str = Body(description='第三方数据接口')
359376
authorization: str = Body(description='第三方接口凭证')
360377
stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True)
378+
379+
380+
class SystemPromptMessage(SystemMessage):
381+
sqlbot_system: bool = True
382+
383+
def __init__(
384+
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
385+
) -> None:
386+
super().__init__(content=content, **kwargs)
387+
388+
389+
class HumanPromptMessage(HumanMessage):
390+
sqlbot_system: bool = True
391+
392+
def __init__(
393+
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
394+
) -> None:
395+
super().__init__(content=content, **kwargs)
396+
397+
398+
class AIPromptMessage(AIMessage):
399+
sqlbot_system: bool = True
400+
401+
def __init__(
402+
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
403+
) -> None:
404+
super().__init__(content=content, **kwargs)

0 commit comments

Comments
 (0)