diff --git a/apps/application/flow/step_node/reranker_node/i_reranker_node.py b/apps/application/flow/step_node/reranker_node/i_reranker_node.py index 2829eb50e5f..af87a6f2003 100644 --- a/apps/application/flow/step_node/reranker_node/i_reranker_node.py +++ b/apps/application/flow/step_node/reranker_node/i_reranker_node.py @@ -31,7 +31,9 @@ class RerankerStepNodeSerializer(serializers.Serializer): reranker_setting = RerankerSettingSerializer(required=True) question_reference_address = serializers.ListField(required=True) - reranker_model_id = serializers.UUIDField(required=True) + reranker_model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True) + reranker_model_id_type = serializers.CharField(required=False, default='custom') + reranker_model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True) reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True)) show_knowledge = serializers.BooleanField(required=True, label=_("The results are displayed in the knowledge sources")) @@ -55,9 +57,27 @@ def _run(self): reference[0], reference[1:]) for reference in self.node_params_serializer.data.get('reranker_reference_list')] - return self.execute(**self.node_params_serializer.data, question=str(question), - reranker_list=reranker_list) + node_params_data = dict(self.node_params_serializer.data) + + reranker_model_id_type = node_params_data.pop('reranker_model_id_type', None) + reranker_model_id_reference = node_params_data.pop('reranker_model_id_reference', None) + reranker_model_id = node_params_data.pop('reranker_model_id', None) + + # 处理引用类型 + if reranker_model_id_type == 'reference' and reranker_model_id_reference: + reference_data = self.workflow_manage.get_reference_field( + reranker_model_id_reference[0], + reranker_model_id_reference[1:], + ) + if reference_data and isinstance(reference_data, dict): + reranker_model_id = reference_data.get('reranker_model_id', + reference_data.get('model_id', reranker_model_id)) + if reranker_model_id is None or reranker_model_id == '': + raise Exception(_('Model is not allowed to be empty')) + + return self.execute(**node_params_data, question=str(question), + reranker_list=reranker_list, reranker_model_id=reranker_model_id) def execute(self, question, reranker_setting, reranker_list, reranker_model_id, show_knowledge, **kwargs) -> NodeResult: diff --git a/ui/src/workflow/nodes/reranker-node/index.vue b/ui/src/workflow/nodes/reranker-node/index.vue index ff8e481f8b5..efcb2759bb3 100644 --- a/ui/src/workflow/nodes/reranker-node/index.vue +++ b/ui/src/workflow/nodes/reranker-node/index.vue @@ -122,32 +122,62 @@ - +
+ +
+
+ >() const form = { reranker_reference_list: [[]], reranker_model_id: '', + reranker_model_id_type: 'custom', + reranker_model_id_reference: [], question_reference_address: [], reranker_setting: { top_n: 3, @@ -222,6 +254,12 @@ const wheel = (e: any) => { const form_data = computed({ get: () => { if (props.nodeModel.properties.node_data) { + if (!props.nodeModel.properties.node_data.reranker_model_id_type) { + set(props.nodeModel.properties.node_data, 'reranker_model_id_type', 'custom') + } + if (!props.nodeModel.properties.node_data.reranker_model_id_reference) { + set(props.nodeModel.properties.node_data, 'reranker_model_id_reference', []) + } return props.nodeModel.properties.node_data } else { set(props.nodeModel.properties, 'node_data', form) @@ -232,10 +270,13 @@ const form_data = computed({ set(props.nodeModel.properties, 'node_data', value) }, }) + function refreshParam(data: any) { set(props.nodeModel.properties.node_data, 'reranker_setting', data) } +const modelCascaderRef = ref() + const resource = getResourceDetail() function getSelectModel() { const obj = @@ -264,6 +305,7 @@ const nodeCascaderRef = ref() const validate = () => { return Promise.all([ nodeCascaderRef.value ? nodeCascaderRef.value.validate() : Promise.resolve(''), + modelCascaderRef.value ? modelCascaderRef.value.validate() : Promise.resolve(''), rerankerNodeFormRef.value?.validate(), ]).catch((err: any) => { return Promise.reject({ node: props.nodeModel, errMessage: err })