Skip to content

Commit 3184b66

Browse files
committed
feat: support specifying datasource ID in MCP question
1 parent da807a5 commit 3184b66

File tree

5 files changed

+36
-9
lines changed

5 files changed

+36
-9
lines changed

backend/apps/chat/api/chat.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from sqlalchemy import and_, select
1111
from starlette.responses import JSONResponse
1212

13-
from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, list_chats, get_chat_with_records, create_chat, rename_chat, \
13+
from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, \
14+
list_chats, get_chat_with_records, create_chat, rename_chat, \
1415
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
15-
format_json_data, format_json_list_data, get_chart_config, list_recent_questions,get_chat as get_chat_exec, rename_chat_with_user
16+
format_json_data, format_json_list_data, get_chart_config, list_recent_questions, get_chat as get_chat_exec, \
17+
rename_chat_with_user
1618
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \
1719
ChatInfo, Chat, ChatFinishStep
1820
from apps.chat.task.llm import LLMService
@@ -69,6 +71,7 @@ def inner():
6971
7072
return await asyncio.to_thread(inner) """
7173

74+
7275
@router.get("/record/{chat_record_id}/data", summary=f"{PLACEHOLDER_PREFIX}get_chart_data")
7376
async def chat_record_data(session: SessionDep, current_user: CurrentUser, chat_record_id: int):
7477
def inner():
@@ -81,7 +84,8 @@ def inner():
8184
@router.get("/record/{chat_record_id}/predict_data", summary=f"{PLACEHOLDER_PREFIX}get_chart_predict_data")
8285
async def chat_predict_data(session: SessionDep, current_user: CurrentUser, chat_record_id: int):
8386
def inner():
84-
data = get_chat_predict_data_with_user(chat_record_id=chat_record_id, session=session, current_user=current_user)
87+
data = get_chat_predict_data_with_user(chat_record_id=chat_record_id, session=session,
88+
current_user=current_user)
8589
return format_json_list_data(data)
8690

8791
return await asyncio.to_thread(inner)
@@ -102,6 +106,7 @@ async def rename(session: SessionDep, chat: RenameChat):
102106
detail=str(e)
103107
) """
104108

109+
105110
@router.post("/rename", response_model=str, summary=f"{PLACEHOLDER_PREFIX}rename_chat")
106111
@system_log(LogConfig(
107112
operation_type=OperationType.UPDATE,
@@ -117,6 +122,7 @@ async def rename(session: SessionDep, current_user: CurrentUser, chat: RenameCha
117122
detail=str(e)
118123
)
119124

125+
120126
""" @router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat")
121127
@system_log(LogConfig(
122128
operation_type=OperationType.DELETE,
@@ -133,6 +139,7 @@ async def delete(session: SessionDep, chart_id: int, brief: str):
133139
detail=str(e)
134140
) """
135141

142+
136143
@router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat")
137144
@system_log(LogConfig(
138145
operation_type=OperationType.DELETE,
@@ -149,6 +156,7 @@ async def delete(session: SessionDep, current_user: CurrentUser, chart_id: int,
149156
detail=str(e)
150157
)
151158

159+
152160
@router.post("/start", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}start_chat")
153161
@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="create_chat_obj.datasource"))
154162
@system_log(LogConfig(
@@ -172,9 +180,11 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
172180
module=OperationModules.CHAT,
173181
result_id_expr="id"
174182
))
175-
async def start_chat(session: SessionDep, current_user: CurrentUser, current_assistant: CurrentAssistant, create_chat_obj: CreateChat = CreateChat(origin=2)):
183+
async def start_chat(session: SessionDep, current_user: CurrentUser, current_assistant: CurrentAssistant,
184+
create_chat_obj: CreateChat = CreateChat(origin=2)):
176185
try:
177-
return create_chat(session, current_user, create_chat_obj, create_chat_obj and create_chat_obj.datasource, current_assistant)
186+
return create_chat(session, current_user, create_chat_obj, create_chat_obj and create_chat_obj.datasource,
187+
current_assistant)
178188
except Exception as e:
179189
raise HTTPException(
180190
status_code=500,
@@ -213,7 +223,7 @@ def _err(_e: Exception):
213223

214224
@router.get("/recent_questions/{datasource_id}", response_model=List[str],
215225
summary=f"{PLACEHOLDER_PREFIX}get_recommend_questions")
216-
#@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="datasource_id"))
226+
# @require_permissions(permission=SqlbotPermission(type='ds', keyExpression="datasource_id"))
217227
async def recommend_questions(session: SessionDep, current_user: CurrentUser,
218228
datasource_id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id")):
219229
return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id)
@@ -442,8 +452,8 @@ def _err(_e: Exception):
442452

443453

444454
@router.get("/record/{chat_record_id}/excel/export/{chat_id}", summary=f"{PLACEHOLDER_PREFIX}export_chart_data")
445-
@system_log(LogConfig(operation_type=OperationType.EXPORT,module=OperationModules.CHAT,resource_id_expr="chat_id",))
446-
async def export_excel(session: SessionDep, current_user: CurrentUser, chat_record_id: int,chat_id: int, trans: Trans):
455+
@system_log(LogConfig(operation_type=OperationType.EXPORT, module=OperationModules.CHAT, resource_id_expr="chat_id", ))
456+
async def export_excel(session: SessionDep, current_user: CurrentUser, chat_record_id: int, chat_id: int, trans: Trans):
447457
chat_record = session.get(ChatRecord, chat_record_id)
448458
if not chat_record:
449459
raise HTTPException(

backend/apps/chat/curd/chat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,8 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
493493
ds.type_name = DB.get_db(ds.type)
494494
else:
495495
ds = session.get(CoreDatasource, create_chat_obj.datasource)
496+
if ds.oid != current_user.oid:
497+
raise Exception(f"Datasource with id {create_chat_obj.datasource} does not belong to current workspace")
496498

497499
if not ds:
498500
raise Exception(f"Datasource with id {create_chat_obj.datasource} not found")

backend/apps/chat/models/chat_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def dynamic_user_question(self):
282282

283283
class ChatQuestion(AiModelQuestion):
284284
chat_id: int
285+
datasource_id: Optional[int] = None
285286

286287

287288
class ChatMcp(ChatQuestion):
@@ -299,6 +300,7 @@ class McpQuestion(BaseModel):
299300
token: str = Body(description='token')
300301
stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True)
301302
lang: Optional[str] = Body(description='语言:zh-CN|en|ko-KR', default='zh-CN')
303+
datasource_id: Optional[int] = Body(description='数据源ID,仅当当前对话没有确定数据源时有效', default=None)
302304

303305

304306
class AxisObj(BaseModel):

backend/apps/chat/task/llm.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
107107
if not chat:
108108
raise SingleMessageError(f"Chat with id {chat_id} not found")
109109
ds: CoreDatasource | AssistantOutDsSchema | None = None
110+
if not chat.datasource and chat_question.datasource_id:
111+
_ds = session.get(CoreDatasource, chat_question.datasource_id)
112+
if _ds:
113+
if _ds.oid != current_user.oid:
114+
raise SingleMessageError(f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
115+
chat.datasource = _ds.id
116+
chat.engine_type = _ds.type_name
117+
# save chat
118+
session.add(chat)
119+
session.flush()
120+
session.refresh(chat)
121+
session.commit()
122+
110123
if chat.datasource:
111124
# Get available datasource
112125
if current_assistant and current_assistant.type in dynamic_ds_types:

backend/apps/mcp/mcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ async def mcp_start(session: SessionDep, chat: ChatStart):
114114
async def mcp_question(session: SessionDep, chat: McpQuestion):
115115
session_user = get_user(session, chat.token)
116116

117-
mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question)
117+
mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question, datasource_id=chat.datasource_id)
118118

119119
return await question_answer_inner(session=session, current_user=session_user, request_question=mcp_chat,
120120
in_chat=False, stream=chat.stream)

0 commit comments

Comments
 (0)