|
22 | 22 | from application.models import ChatRecord, ChatUserType |
23 | 23 | from common.field.common import InstanceField |
24 | 24 | from knowledge.models.knowledge_action import KnowledgeAction, State |
| 25 | +from tools.models import ToolRecord |
25 | 26 |
|
26 | 27 | chat_cache = cache |
27 | 28 |
|
@@ -115,6 +116,40 @@ def handler(self, workflow): |
115 | 116 | 'start_time') is not None else 0) |
116 | 117 |
|
117 | 118 |
|
| 119 | +def get_tool_workflow_state(workflow): |
| 120 | + if workflow.is_the_task_interrupted(): |
| 121 | + return State.REVOKED |
| 122 | + details = workflow.get_runtime_details() |
| 123 | + node_list = details.values() |
| 124 | + all_node = [*node_list, *get_loop_workflow_node(node_list)] |
| 125 | + err = any([True for value in all_node if value.get('status') == 500 and not value.get('enableException')]) |
| 126 | + if err: |
| 127 | + return State.FAILURE |
| 128 | + return State.SUCCESS |
| 129 | + |
| 130 | + |
| 131 | +class ToolWorkflowPostHandler(WorkFlowPostHandler): |
| 132 | + def __init__(self, chat_info, tool_id): |
| 133 | + super().__init__(chat_info) |
| 134 | + self.tool_id = tool_id |
| 135 | + |
| 136 | + def handler(self, workflow): |
| 137 | + state = get_tool_workflow_state(workflow) |
| 138 | + record = ToolRecord(id=self.chat_info.tool_record_id, tool_id=self.tool_id, |
| 139 | + workspace_id=self.chat_info.workspace_id, |
| 140 | + source_type=self.chat_info.source_type, |
| 141 | + source_id=self.chat_info.source_id, |
| 142 | + state=state, |
| 143 | + meta={ |
| 144 | + 'output': workflow.out_context, |
| 145 | + 'details': workflow.get_runtime_details(), |
| 146 | + 'answer_text_list': workflow.get_answer_text_list() |
| 147 | + }) |
| 148 | + self.chat_info.set_record(record) |
| 149 | + self.chat_info = None |
| 150 | + self.tool_id = None |
| 151 | + |
| 152 | + |
118 | 153 | def get_loop_workflow_node(node_list): |
119 | 154 | result = [] |
120 | 155 | for item in node_list: |
@@ -204,6 +239,11 @@ class KnowledgeFlowParamsSerializer(serializers.Serializer): |
204 | 239 | knowledge_base = serializers.DictField(required=False, label="知识库设置") |
205 | 240 |
|
206 | 241 |
|
| 242 | +class ToolFlowParamsSerializer(serializers.Serializer): |
| 243 | + tool_id = serializers.UUIDField(required=True, label="工具id") |
| 244 | + workspace_id = serializers.CharField(required=True, label="工作空间id") |
| 245 | + |
| 246 | + |
207 | 247 | class INode: |
208 | 248 | view_type = 'many_view' |
209 | 249 |
|
|
0 commit comments