-
Notifications
You must be signed in to change notification settings - Fork 690
Expand file tree
/
Copy pathchat_model.py
More file actions
303 lines (243 loc) · 13 KB
/
chat_model.py
File metadata and controls
303 lines (243 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
from datetime import datetime
from enum import Enum
from typing import List, Optional, Union
from fastapi import Body
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean
from sqlalchemy import Enum as SQLAlchemyEnum
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import SQLModel, Field
from apps.db.constant import DB
from apps.template.filter.generator import get_permissions_template
from apps.template.generate_analysis.generator import get_analysis_template
from apps.template.generate_chart.generator import get_chart_template
from apps.template.generate_dynamic.generator import get_dynamic_template
from apps.template.generate_guess_question.generator import get_guess_question_template
from apps.template.generate_predict.generator import get_predict_template
from apps.template.generate_sql.generator import get_sql_template, get_sql_example_template
from apps.template.prase_sql.generator import get_prase_sql_template
from apps.template.select_datasource.generator import get_datasource_template
def enum_values(enum_class: type[Enum]) -> list:
    """Return the ``value`` of each member of *enum_class*, in definition order.

    Passed as SQLAlchemy's ``values_callable`` by the models in this module so the
    enum columns persist member values rather than member names.
    """
    values = []
    for member in enum_class:
        values.append(member.value)
    return values
class TypeEnum(Enum):
    """Kind of interaction recorded in ``ChatLog.type`` (persisted as the string value)."""

    CHAT = "0"
    # TODO: add members for other usages.
class OperationEnum(Enum):
    """AI-model operation recorded in ``ChatLog.operate`` (persisted as the string value)."""

    GENERATE_SQL = "0"
    GENERATE_CHART = "1"
    ANALYSIS = "2"
    PREDICT_DATA = "3"
    GENERATE_RECOMMENDED_QUESTIONS = "4"
    GENERATE_SQL_WITH_PERMISSIONS = "5"
    CHOOSE_DATASOURCE = "6"
    GENERATE_DYNAMIC_SQL = "7"
    # "PRASE" is a misspelling of "parse"; the name is kept because callers reference it.
    PRASE_SQL = "8"
class ChatFinishStep(Enum):
    """Step of the chat pipeline, numbered in execution order starting at 1."""

    GENERATE_SQL = 1
    QUERY_DATA = 2
    GENERATE_CHART = 3
    # TODO: choose table / check connection / generate description
class ChatLog(SQLModel, table=True):
    """Log entry for one AI-model operation (table ``chat_log``).

    Columns explicitly declared ``nullable=True`` are annotated ``Optional[...]`` so
    the pydantic types agree with the database schema (a stored NULL is a valid value).
    """

    __tablename__ = "chat_log"
    id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
    # Both enums are stored as their string values (values_callable=enum_values) in a
    # VARCHAR(3) column rather than a native database enum.
    type: TypeEnum = Field(
        sa_column=Column(SQLAlchemyEnum(TypeEnum, native_enum=False, values_callable=enum_values, length=3)))
    operate: OperationEnum = Field(
        sa_column=Column(SQLAlchemyEnum(OperationEnum, native_enum=False, values_callable=enum_values, length=3)))
    # Parent/related id; its target is not visible in this module — TODO confirm.
    pid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
    ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
    base_modal: Optional[str] = Field(max_length=255)
    messages: Optional[list[dict]] = Field(sa_column=Column(JSONB))
    reasoning_content: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
    finish_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
    # JSONB column; may hold a dict or a bare integer.
    token_usage: Optional[dict | int] = Field(sa_column=Column(JSONB))
class Chat(SQLModel, table=True):
    """A chat session (table ``chat``).

    Columns declared nullable are annotated ``Optional[...]`` so the pydantic types
    agree with the database schema.
    """

    __tablename__ = "chat"
    id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
    oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1))
    create_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
    create_by: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
    brief: Optional[str] = Field(max_length=64, nullable=True)
    chat_type: str = Field(max_length=20, default="chat")  # chat, datasource
    datasource: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
    # Not Optional: SQLModel derives a NOT NULL column from the non-Optional annotation.
    engine_type: str = Field(max_length=64)
    origin: Optional[int] = Field(
        sa_column=Column(Integer, nullable=False, default=0))  # 0: default, 1: mcp, 2: assistant
class ChatRecord(SQLModel, table=True):
    """One question/answer round inside a chat (table ``chat_record``).

    Stores the question plus the artefacts produced while answering it (generated
    SQL, query data, chart, analysis, prediction, recommended questions). Nullable
    columns without a column-level default are annotated ``Optional[...]`` so the
    pydantic types agree with the database schema, matching ``ChatRecordResult``.
    """

    __tablename__ = "chat_record"
    id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
    chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
    ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
    # Kept as plain ``bool``: declaring it Optional could make some SQLModel/pydantic
    # versions assign an explicit None on construction, bypassing the column default False.
    first_chat: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
    create_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
    finish_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
    create_by: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
    datasource: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
    engine_type: Optional[str] = Field(max_length=64, nullable=True)
    question: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    sql_answer: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    sql: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    sql_exec_result: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    data: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    chart_answer: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    chart: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    analysis: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    predict: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    predict_data: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    recommended_question_answer: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    recommended_question: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    datasource_select_answer: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    # Plain ``bool`` for the same reason as ``first_chat`` (column default False).
    finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
    error: Optional[str] = Field(sa_column=Column(Text, nullable=True))
    analysis_record_id: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
    predict_record_id: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
    sql_prase: Optional[str] = Field(sa_column=Column(Text, nullable=True))
class ChatRecordResult(BaseModel):
    """Plain (non-table) view of a chat record.

    Carries a subset of the ``ChatRecord`` columns, every one optional, plus one
    ``*_reasoning_content`` field per generation step (SQL, chart, analysis,
    predict). Where the reasoning text comes from is not visible here —
    presumably ``ChatLog.reasoning_content``; TODO confirm.
    """
    id: Optional[int] = None
    chat_id: Optional[int] = None
    ai_modal_id: Optional[int] = None
    first_chat: bool = False
    create_time: Optional[datetime] = None
    finish_time: Optional[datetime] = None
    question: Optional[str] = None
    sql_answer: Optional[str] = None
    sql: Optional[str] = None
    data: Optional[str] = None
    chart_answer: Optional[str] = None
    chart: Optional[str] = None
    analysis: Optional[str] = None
    predict: Optional[str] = None
    predict_data: Optional[str] = None
    recommended_question: Optional[str] = None
    datasource_select_answer: Optional[str] = None
    finish: Optional[bool] = None
    error: Optional[str] = None
    analysis_record_id: Optional[int] = None
    predict_record_id: Optional[int] = None
    # Per-step reasoning text; these fields have no counterpart column in ChatRecord.
    sql_reasoning_content: Optional[str] = None
    chart_reasoning_content: Optional[str] = None
    analysis_reasoning_content: Optional[str] = None
    predict_reasoning_content: Optional[str] = None
class CreateChat(BaseModel):
    """Parameters for creating a chat.

    Every field defaults to ``None`` except ``origin``, so the annotations are
    ``Optional`` to accept an explicit null as well as an omitted value.
    """

    id: Optional[int] = None
    question: Optional[str] = None
    datasource: Optional[int] = None
    origin: Optional[int] = 0  # 0: web page, 1: MCP, 2: assistant
class RenameChat(BaseModel):
    """Parameters for renaming a chat: its ``id`` and the new ``brief`` (title)."""

    # Defaults to None, so annotated Optional to accept an explicit null too.
    id: Optional[int] = None
    brief: str = ''
class ChatInfo(BaseModel):
    """Aggregated view of a chat session, its datasource details and its records.

    Fields defaulting to ``None`` are annotated ``Optional`` so a missing value
    neither fails validation nor triggers a serialization type warning.
    """

    id: Optional[int] = None
    create_time: Optional[datetime] = None
    create_by: Optional[int] = None
    brief: str = ''
    chat_type: str = "chat"
    datasource: Optional[int] = None
    engine_type: str = ''
    ds_type: str = ''
    datasource_name: str = ''
    datasource_exists: bool = True
    # Records may be ORM rows or already-serialized dicts.
    records: List[ChatRecord | dict] = []
class AiModelQuestion(BaseModel):
    """Context for an AI-model request plus builders for every prompt template.

    Each ``*_sys_question`` / ``*_user_question`` pair renders the system and user
    prompt of one task by calling ``str.format`` on the ``'system'`` / ``'user'``
    entry of the template returned by the matching ``get_*_template()`` helper,
    filling the placeholders from this model's fields.
    """

    question: Optional[str] = None
    ai_modal_id: Optional[int] = None
    ai_modal_name: Optional[str] = None  # Specific model name
    engine: str = ""
    db_schema: str = ""
    sql: str = ""
    rule: str = ""
    fields: str = ""
    data: str = ""
    lang: str = "简体中文"
    # Inserted into the permissions prompt by ``filter_user_question``. The default is
    # an empty list, so lists are accepted in addition to pre-serialized strings.
    filter: Union[str, list] = []
    sub_query: Optional[list[dict]] = None
    terminologies: str = ""
    data_training: str = ""
    custom_prompt: str = ""
    error_msg: str = ""

    def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
        """Render the system prompt for SQL generation.

        :param db_type: datasource type; selects the engine-specific quoting/limit
            rules and example answers via ``get_sql_example_template``.
        :param enable_query_limit: when true, include the ``query_limit`` rule and use
            the ``*_with_limit`` variants of the three example answers.
        :return: the formatted system prompt.
        """
        example_template = get_sql_example_template(db_type)
        sql_template = get_sql_template()  # fetched once and reused for both keys
        # Example answers exist in two variants; pick the one matching the limit setting.
        answer_suffix = '_with_limit' if enable_query_limit else ''
        return sql_template['system'].format(
            engine=self.engine,
            schema=self.db_schema,
            question=self.question,
            lang=self.lang,
            terminologies=self.terminologies,
            data_training=self.data_training,
            custom_prompt=self.custom_prompt,
            base_sql_rules=(example_template['quot_rule'] + example_template['limit_rule']
                            + example_template['other_rule']),
            query_limit=sql_template['query_limit'] if enable_query_limit else '',
            basic_sql_examples=example_template['basic_example'],
            example_engine=example_template['example_engine'],
            example_answer_1=example_template[f'example_answer_1{answer_suffix}'],
            example_answer_2=example_template[f'example_answer_2{answer_suffix}'],
            example_answer_3=example_template[f'example_answer_3{answer_suffix}'],
        )

    def sql_user_question(self, current_time: str):
        """Render the user prompt for SQL generation; ``current_time`` is inserted as given."""
        return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
                                                 rule=self.rule, current_time=current_time, error_msg=self.error_msg)

    def chart_sys_question(self):
        """Render the system prompt for chart generation."""
        return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)

    def chart_user_question(self, chart_type: Optional[str] = None):
        """Render the user prompt for chart generation.

        Note: when ``chart_type`` is None, ``str.format`` inserts the literal text ``None``.
        """
        return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
                                                   chart_type=chart_type)

    def analysis_sys_question(self):
        """Render the system prompt for data analysis."""
        return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies,
                                                        custom_prompt=self.custom_prompt)

    def analysis_user_question(self):
        """Render the user prompt for data analysis from ``fields`` and ``data``."""
        return get_analysis_template()['user'].format(fields=self.fields, data=self.data)

    def predict_sys_question(self):
        """Render the system prompt for data prediction."""
        return get_predict_template()['system'].format(lang=self.lang, custom_prompt=self.custom_prompt)

    def predict_user_question(self):
        """Render the user prompt for data prediction from ``fields`` and ``data``."""
        return get_predict_template()['user'].format(fields=self.fields, data=self.data)

    def prase_sql_sys_question(self):
        """Render the system prompt for SQL parsing ("prase" is a legacy misspelling of "parse")."""
        return get_prase_sql_template()['system'].format(lang=self.lang)

    def prase_sql_user_question(self, sql='', chart=''):
        """Render the user prompt for SQL parsing from the given ``sql`` and ``chart`` text."""
        return get_prase_sql_template()['user'].format(sql=sql, chart=chart)

    def datasource_sys_question(self):
        """Render the system prompt for datasource selection."""
        return get_datasource_template()['system'].format(lang=self.lang)

    def datasource_user_question(self, datasource_list: str = "[]"):
        """Render the user prompt for datasource selection; ``datasource_list`` is inserted as ``data``."""
        return get_datasource_template()['user'].format(question=self.question, data=datasource_list)

    def guess_sys_question(self):
        """Render the system prompt for recommended-question generation."""
        return get_guess_question_template()['system'].format(lang=self.lang)

    def guess_user_question(self, old_questions: str = "[]"):
        """Render the user prompt for recommended-question generation, given earlier questions."""
        return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema,
                                                            old_questions=old_questions)

    def filter_sys_question(self):
        """Render the system prompt for applying permission filters to SQL."""
        return get_permissions_template()['system'].format(lang=self.lang, engine=self.engine)

    def filter_user_question(self):
        """Render the user prompt for permission filtering; a list ``filter`` is inserted via ``str()``."""
        return get_permissions_template()['user'].format(sql=self.sql, filter=self.filter)

    def dynamic_sys_question(self):
        """Render the system prompt for dynamic SQL generation."""
        return get_dynamic_template()['system'].format(lang=self.lang, engine=self.engine)

    def dynamic_user_question(self):
        """Render the user prompt for dynamic SQL; ``sub_query`` is inserted via ``str()``."""
        return get_dynamic_template()['user'].format(sql=self.sql, sub_query=self.sub_query)
class ChatQuestion(AiModelQuestion):
    """``AiModelQuestion`` plus the required id of the chat it belongs to."""
    chat_id: int
class ChatMcp(ChatQuestion):
    """``ChatQuestion`` that also carries a required ``token`` (presumably for the MCP entry point — TODO confirm)."""
    token: str
class ChatStart(BaseModel):
    """Login credentials: username and password.

    Field defaults are FastAPI ``Body`` objects with no explicit default value, so
    both fields are presumably required — confirm against the FastAPI version in use.
    """
    username: str = Body(description='用户名')  # description means "username"
    password: str = Body(description='密码')  # description means "password"
class McpQuestion(BaseModel):
    """Question submitted through MCP: the text, target chat id, token and streaming flag."""
    question: str = Body(description='用户提问')  # description means "user question"
    chat_id: int = Body(description='会话ID')  # description means "chat (session) ID"
    token: str = Body(description='token')
    # description means: "whether to stream output; defaults to true (on); when false,
    # a JSON object is returned instead"
    stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True)
class AxisObj(BaseModel):
    """Axis descriptor used by ``ExcelData``: a ``name``, a ``value`` and an optional ``type``.

    The exact meaning of ``value`` (e.g. the key into each data row) is not shown
    here — TODO confirm with the exporter.
    """
    name: str = ''
    value: str = ''
    type: str | None = None
class ExcelData(BaseModel):
    """Tabular data for an Excel output: axis descriptors, row dicts and a name (default ``'Excel'``)."""
    axis: list[AxisObj] = []
    data: list[dict] = []
    name: str = 'Excel'
class McpAssistant(BaseModel):
    """Assistant question answered against a third-party data API."""
    question: str = Body(description='用户提问')  # description means "user question"
    url: str = Body(description='第三方数据接口')  # description means "third-party data API"
    authorization: str = Body(description='第三方接口凭证')  # description means "third-party API credential"
    # description means: "whether to stream output; defaults to true (on); when false,
    # a JSON object is returned instead"
    stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True)