Skip to content

Commit 561ce1b

Browse files
committed
feat: trigger task record
1 parent 16b2b2a commit 561ce1b

File tree

16 files changed

+349
-177
lines changed

16 files changed

+349
-177
lines changed

apps/application/chat_pipeline/I_base_chat_pipeline.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import time
1010
from abc import abstractmethod
1111
from typing import Type
12-
12+
import uuid_utils.compat as uuid
1313
from rest_framework import serializers
1414

1515
from knowledge.models import Paragraph
@@ -128,6 +128,8 @@ class IBaseChatPipelineStep:
128128
def __init__(self):
129129
# 当前步骤上下文,用于存储当前步骤信息
130130
self.context = {}
131+
self.status = 200
132+
self.err_message = ''
131133

132134
@abstractmethod
133135
def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
@@ -145,12 +147,29 @@ def run(self, manage):
145147
:param manage: 步骤管理器
146148
:return: 执行结果
147149
"""
148-
start_time = time.time()
149-
self.context['start_time'] = start_time
150-
# 校验参数,
151-
self.valid_args(manage)
152-
self._run(manage)
153-
self.context['run_time'] = time.time() - start_time
150+
try:
151+
start_time = time.time()
152+
self.context['start_time'] = start_time
153+
# 校验参数,
154+
self.valid_args(manage)
155+
self._run(manage)
156+
self.context['run_time'] = time.time() - start_time
157+
except Exception as e:
158+
self.err_message = str(e)
159+
self.status = 500
160+
chat_record_id = manage.context.get('chat_record_id') or str(uuid.uuid7())
161+
manage.context['message_tokens'] = 0
162+
manage.context['answer_tokens'] = 0
163+
end_time = time.time()
164+
manage.context['run_time'] = end_time - (manage.context.get('start_time') or end_time)
165+
post_response_handler = manage.context.get('post_response_handler')
166+
post_response_handler.handler(manage.context.get('chat_id'), chat_record_id,
167+
manage.context.get('paragraph_list') or [],
168+
manage.context.get('problem_text'),
169+
str(e), manage, self, manage.context.get('padding_problem_text'),
170+
reasoning_content='')
171+
172+
raise e
154173

155174
def _run(self, manage):
156175
pass

apps/application/chat_pipeline/pipeline_manage.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
2121
debug=False):
2222
# 步骤执行器
2323
self.step_list = [step() for step in step_list]
24+
self.run_step_list = []
2425
# 上下文
2526
self.context = {'message_tokens': 0, 'answer_tokens': 0}
2627
self.base_to_response = base_to_response
@@ -32,12 +33,13 @@ def run(self, context: Dict = None):
3233
for key, value in context.items():
3334
self.context[key] = value
3435
for step in self.step_list:
36+
self.run_step_list.append(step)
3537
step.run(self)
3638

3739
def get_details(self):
3840
return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in
3941
filter(lambda r: r is not None,
40-
[row.get_details(self) for row in self.step_list])], {})
42+
[row.get_details(self) for row in self.run_step_list])], {})
4143

4244
def get_base_to_response(self):
4345
return self.base_to_response

apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,10 @@ def execute(self, message_list: List[BaseMessage],
204204

205205
def get_details(self, manage, **kwargs):
206206
return {
207+
'status': self.status,
208+
'err_message': self.err_message,
207209
'step_type': 'chat_step',
208-
'run_time': self.context['run_time'],
210+
'run_time': self.context.get('run_time') or 0,
209211
'model_id': str(manage.context['model_id']),
210212
'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
211213
self.context['answer_text']),
@@ -273,7 +275,8 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
273275
if application_access_token is not None and application_access_token.authentication:
274276
raise AppApiException(
275277
500,
276-
_('Agent 【{name}】 access token authentication is not supported for agent tool').format(name=app.name)
278+
_('Agent 【{name}】 access token authentication is not supported for agent tool').format(
279+
name=app.name)
277280
)
278281
else:
279282
raise AppApiException(

apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,10 @@ def to_human_message(prompt: str,
7171
data_list.append(f"<data>{content}</data>")
7272
data = "\n".join(data_list)
7373
return HumanMessage(content=prompt.replace('{data}', data).replace('{question}', problem))
74+
75+
def get_details(self, manage, **kwargs):
76+
return {
77+
'status': self.status,
78+
'err_message': self.err_message,
79+
'step_type': 'generate_human_message',
80+
}

apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,14 @@ def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = Non
5656
return padding_problem
5757

5858
def get_details(self, manage, **kwargs):
59-
return {
60-
'step_type': 'problem_padding',
61-
'run_time': self.context['run_time'],
62-
'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
63-
'message_tokens': self.context.get('message_tokens', 0),
64-
'answer_tokens': self.context.get('answer_tokens', 0),
65-
'cost': 0,
66-
'padding_problem_text': self.context.get('padding_problem_text'),
67-
'problem_text': self.context.get("step_args").get('problem_text'),
68-
}
59+
return {'status': self.status,
60+
'err_message': self.err_message,
61+
'step_type': 'problem_padding',
62+
'run_time': self.context['run_time'],
63+
'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
64+
'message_tokens': self.context.get('message_tokens', 0),
65+
'answer_tokens': self.context.get('answer_tokens', 0),
66+
'cost': 0,
67+
'padding_problem_text': self.context.get('padding_problem_text'),
68+
'problem_text': self.context.get("step_args").get('problem_text'),
69+
}

apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_docum
6868
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
6969
embedding_value = embedding_model.embed_query(exec_problem_text)
7070
vector = VectorStore.get_embedding_vector()
71-
embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, None, exclude_document_id_list,
71+
embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, None,
72+
exclude_document_id_list,
7273
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
7374
if embedding_list is None:
7475
return []
@@ -132,12 +133,14 @@ def list_paragraph(embedding_list: List, vector):
132133
return paragraph_list
133134

134135
def get_details(self, manage, **kwargs):
135-
step_args = self.context['step_args']
136+
step_args = self.context.get('step_args') or {}
136137

137138
return {
139+
'status': self.status,
140+
'err_message': self.err_message,
138141
'step_type': 'search_step',
139-
'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']],
140-
'run_time': self.context['run_time'],
142+
'paragraph_list': [row.to_dict() for row in (self.context.get('paragraph_list') or [])],
143+
'run_time': self.context.get('run_time') or 0,
141144
'problem_text': step_args.get(
142145
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
143146
'model_name': self.context.get('model_name'),

apps/application/serializers/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def append_chat_record(self, chat_record: ChatRecord):
263263
'improve_paragraph_id_list': chat_record.improve_paragraph_id_list,
264264
'run_time': chat_record.run_time,
265265
'source': chat_record.source,
266-
'ip_address': chat_record.ip_address,
266+
'ip_address': chat_record.ip_address or '',
267267
'index': chat_record.index},
268268
defaults={
269269
"vote_status": chat_record.vote_status,
@@ -278,7 +278,7 @@ def append_chat_record(self, chat_record: ChatRecord):
278278
'run_time': chat_record.run_time,
279279
'index': chat_record.index,
280280
'source': chat_record.source,
281-
'ip_address': chat_record.ip_address,
281+
'ip_address': chat_record.ip_address or '',
282282
})
283283
ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
284284

apps/chat/serializers/chat.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from common.handle.base_to_response import BaseToResponse
3737
from common.handle.impl.response.openai_to_response import OpenaiToResponse
3838
from common.handle.impl.response.system_to_response import SystemToResponse
39-
from common.utils.common import flat_map, get_file_content
39+
from common.utils.common import flat_map, get_file_content, is_valid_uuid
4040
from knowledge.models import Document, Paragraph
4141
from maxkb.conf import PROJECT_DIR
4242
from models_provider.models import Model, Status
@@ -381,7 +381,8 @@ def get_chat_record(chat_info, chat_record_id):
381381
return chat_record_list[-1]
382382
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first()
383383
if chat_record is None:
384-
raise ChatException(500, _("Conversation record does not exist"))
384+
if not is_valid_uuid(chat_record_id):
385+
raise ChatException(500, _("Conversation record does not exist"))
385386
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id).first()
386387
return chat_record
387388

@@ -406,12 +407,13 @@ def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
406407
history_chat_record = chat_info.chat_record_list
407408
if chat_record_id is not None:
408409
chat_record = self.get_chat_record(chat_info, chat_record_id)
409-
history_chat_record = [r for r in chat_info.chat_record_list if str(r.id) != chat_record_id]
410+
if chat_record:
411+
history_chat_record = [r for r in chat_info.chat_record_list if str(r.id) != chat_record_id]
410412
work_flow = chat_info.application.work_flow
411413
work_flow_manage = WorkflowManage(Workflow.new_instance(work_flow),
412414
{'history_chat_record': history_chat_record, 'question': message,
413415
'chat_id': chat_info.chat_id, 'chat_record_id': str(
414-
uuid.uuid7()) if chat_record is None else str(chat_record.id),
416+
uuid.uuid7()) if chat_record_id is None else str(chat_record_id),
415417
'stream': stream,
416418
're_chat': re_chat,
417419
'chat_user_id': chat_user_id,

apps/common/utils/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,12 @@ def filter_special_character(_str):
351351
for t in s_list:
352352
_str = _str.replace(t, '')
353353
return _str
354+
355+
356+
def is_valid_uuid(uuid_string):
357+
"""判断字符串是否为有效的UUID"""
358+
try:
359+
uuid_obj = uuid.UUID(uuid_string)
360+
return str(uuid_obj) == uuid_string
361+
except ValueError:
362+
return False

0 commit comments

Comments
 (0)