-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Expand file tree
/
Copy path i_reranker_node.py
More file actions
84 lines (66 loc) · 3.81 KB
/
i_reranker_node.py
File metadata and controls
84 lines (66 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: i_reranker_node.py
@date:2024/9/4 10:40
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.common import WorkflowMode
from application.flow.i_step_node import INode, NodeResult
from django.utils.translation import gettext_lazy as _
class RerankerSettingSerializer(serializers.Serializer):
    """Validate the ``reranker_setting`` section of a reranker node's parameters."""

    # Number of segments to query (top-N).
    top_n = serializers.IntegerField(required=True,
                                     label=_("Reference segment number"))
    # Similarity threshold. Validation accepts the closed range [0, 2].
    # NOTE(review): an older comment described this as a 0-1 value; confirm the
    # intended upper bound before tightening max_value.
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
                                        label=_("Similarity"))
    # Upper limit on the length of a quoted segment.
    max_paragraph_char_number = serializers.IntegerField(required=True,
                                                         label=_("Maximum number of words in a quoted segment"))
class RerankerStepNodeSerializer(serializers.Serializer):
    """Validate the full parameter set of a reranker workflow node."""

    reranker_setting = RerankerSettingSerializer(required=True)
    # Reference path for the question. IRerankerNode._run passes element [0]
    # and the remainder [1:] to workflow_manage.get_reference_field.
    question_reference_address = serializers.ListField(required=True)
    # Model id used directly unless reranker_model_id_type is 'reference'.
    reranker_model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True)
    # 'reference' makes _run resolve the model id through
    # reranker_model_id_reference; any other value (default 'custom') uses
    # reranker_model_id as given.
    reranker_model_id_type = serializers.CharField(required=False, default='custom')
    # Reference path consulted only when reranker_model_id_type == 'reference'.
    reranker_model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True)
    # One reference path per value to be gathered for reranking.
    reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))
    show_knowledge = serializers.BooleanField(required=True,
                                              label=_("The results are displayed in the knowledge sources"))

    def is_valid(self, *, raise_exception=False):
        """Validate the data, always raising on failure.

        ``raise_exception`` is accepted for signature compatibility with DRF
        but deliberately ignored: invalid input always raises a ValidationError.

        Returns:
            The boolean result of DRF's ``Serializer.is_valid`` (previously
            discarded, which made this method return ``None``).
        """
        return super().is_valid(raise_exception=True)
class IRerankerNode(INode):
    """Base interface for the reranker workflow node.

    ``_run`` resolves the question, the values to rerank and the reranker
    model id from the running workflow, then hands everything to ``execute``.
    """

    type = 'reranker-node'
    # Workflow modes in which this node is allowed.
    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP]

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class that validates this node's parameters."""
        return RerankerStepNodeSerializer

    def _run(self):
        """Resolve workflow references and delegate to ``execute``.

        Raises:
            Exception: when no reranker model id can be determined.
        """
        params = self.node_params_serializer.data
        resolve = self.workflow_manage.get_reference_field

        question_address = params.get('question_reference_address')
        question = resolve(question_address[0], question_address[1:])

        reranker_list = []
        for ref_path in params.get('reranker_reference_list'):
            reranker_list.append(resolve(ref_path[0], ref_path[1:]))

        # Everything except the three model-selection keys is forwarded to execute().
        execute_kwargs = dict(params)
        id_type = execute_kwargs.pop('reranker_model_id_type', None)
        id_reference = execute_kwargs.pop('reranker_model_id_reference', None)
        model_id = execute_kwargs.pop('reranker_model_id', None)

        # In 'reference' mode the model id is read from the referenced value,
        # which is only honoured when it is a non-empty dict. 'reranker_model_id'
        # wins over 'model_id'; if neither key exists the configured id stays.
        if id_type == 'reference' and id_reference:
            referenced = resolve(id_reference[0], id_reference[1:])
            if referenced and isinstance(referenced, dict):
                fallback_id = referenced.get('model_id', model_id)
                model_id = referenced.get('reranker_model_id', fallback_id)

        if model_id is None or model_id == '':
            raise Exception(_('Model is not allowed to be empty'))

        return self.execute(**execute_kwargs, question=str(question),
                            reranker_list=reranker_list, reranker_model_id=model_id)

    def execute(self, question, reranker_setting, reranker_list, reranker_model_id, show_knowledge,
                **kwargs) -> NodeResult:
        """Perform the reranking.

        This base implementation does nothing and returns ``None``; presumably
        a concrete node implementation overrides it.

        Args:
            question: The resolved question, already converted with ``str``.
            reranker_setting: Validated ``reranker_setting`` data.
            reranker_list: One resolved value per ``reranker_reference_list`` entry.
            reranker_model_id: The resolved, non-empty reranker model id.
            show_knowledge: Validated ``show_knowledge`` flag.
            **kwargs: Remaining validated parameters (e.g. the reference paths).
        """
        pass