Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 56 additions & 12 deletions backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional, Union
from typing import List, Optional, Any, Union

from fastapi import Body
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean
from sqlalchemy import Enum as SQLAlchemyEnum
Expand Down Expand Up @@ -230,6 +231,7 @@ class AiModelQuestion(BaseModel):
regenerate_record_id: Optional[int] = None

def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
templates: dict[str, str] = {}
_sql_template = get_sql_example_template(db_type)
_base_template = get_sql_template()
_process_check = _sql_template.get('process_check') if _sql_template.get('process_check') else _base_template[
Expand All @@ -245,22 +247,37 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
'example_answer_2']
_example_answer_3 = _sql_template['example_answer_3_with_limit'] if enable_query_limit else _sql_template[
'example_answer_3']
return _base_template['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
lang=self.lang, terminologies=self.terminologies,
data_training=self.data_training, custom_prompt=self.custom_prompt,
process_check=_process_check,
base_sql_rules=_base_sql_rules,
basic_sql_examples=_sql_examples,
example_engine=_example_engine,
example_answer_1=_example_answer_1,
example_answer_2=_example_answer_2,
example_answer_3=_example_answer_3)

templates['system'] = _base_template['system'].format(process_check=_process_check)
templates['rules'] = _base_template['generate_rules'].format(lang=self.lang,
base_sql_rules=_base_sql_rules,
basic_sql_examples=_sql_examples,
example_engine=_example_engine,
example_answer_1=_example_answer_1,
example_answer_2=_example_answer_2,
example_answer_3=_example_answer_3)
templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema)

if self.terminologies:
templates['terminologies'] = _base_template['generate_terminologies_info'].format(
terminologies=self.terminologies)

if self.data_training:
templates['data_training'] = _base_template['generate_data_training_info'].format(
data_training=self.data_training)

if self.custom_prompt:
templates['custom_prompt'] = _base_template['generate_custom_prompt_info'].format(
custom_prompt=self.custom_prompt)

return templates

def sql_user_question(self, current_time: str, change_title: bool):
_question = self.question
if self.regenerate_record_id:
_question = get_sql_template()['regenerate_hint'] + self.question
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=_question,
return get_sql_template()['user'].format(lang=self.lang, engine=self.engine, schema=self.db_schema,
question=_question,
rule=self.rule, current_time=current_time, error_msg=self.error_msg,
change_title=change_title)

Expand Down Expand Up @@ -358,3 +375,30 @@ class McpAssistant(BaseModel):
url: str = Body(description='第三方数据接口')
authorization: str = Body(description='第三方接口凭证')
stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True)


class SystemPromptMessage(SystemMessage):
sqlbot_system: bool = True

def __init__(
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
super().__init__(content=content, **kwargs)


class HumanPromptMessage(HumanMessage):
sqlbot_system: bool = True

def __init__(
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
super().__init__(content=content, **kwargs)


class AIPromptMessage(AIMessage):
sqlbot_system: bool = True

def __init__(
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
super().__init__(content=content, **kwargs)
Loading
Loading