Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions apps/application/flow/step_node/reranker_node/i_reranker_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ class RerankerStepNodeSerializer(serializers.Serializer):
reranker_setting = RerankerSettingSerializer(required=True)

question_reference_address = serializers.ListField(required=True)
reranker_model_id = serializers.UUIDField(required=True)
reranker_model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True)
reranker_model_id_type = serializers.CharField(required=False, default='custom')
reranker_model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True)
reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))
show_knowledge = serializers.BooleanField(required=True,
label=_("The results are displayed in the knowledge sources"))
Expand All @@ -55,9 +57,27 @@ def _run(self):
reference[0],
reference[1:]) for reference in
self.node_params_serializer.data.get('reranker_reference_list')]
return self.execute(**self.node_params_serializer.data, question=str(question),

reranker_list=reranker_list)
node_params_data = dict(self.node_params_serializer.data)

reranker_model_id_type = node_params_data.pop('reranker_model_id_type', None)
reranker_model_id_reference = node_params_data.pop('reranker_model_id_reference', None)
reranker_model_id = node_params_data.pop('reranker_model_id', None)

# 处理引用类型
if reranker_model_id_type == 'reference' and reranker_model_id_reference:
reference_data = self.workflow_manage.get_reference_field(
reranker_model_id_reference[0],
reranker_model_id_reference[1:],
)
if reference_data and isinstance(reference_data, dict):
reranker_model_id = reference_data.get('reranker_model_id',
reference_data.get('model_id', reranker_model_id))
if reranker_model_id is None or reranker_model_id == '':
raise Exception(_('Model is not allowed to be empty'))

return self.execute(**node_params_data, question=str(question),
reranker_list=reranker_list, reranker_model_id=reranker_model_id)

def execute(self, question, reranker_setting, reranker_list, reranker_model_id, show_knowledge,
**kwargs) -> NodeResult:
Expand Down
74 changes: 58 additions & 16 deletions ui/src/workflow/nodes/reranker-node/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -122,32 +122,62 @@
</el-form-item>
<el-form-item
:label="$t('workflow.nodes.rerankerNode.reranker_model.label')"
prop="reranker_model_id"
:prop="
form_data.reranker_model_id_type === 'reference'
? 'reranker_model_id_reference'
: 'reranker_model_id'
"
:rules="{
required: true,
message: $t('workflow.nodes.rerankerNode.reranker_model.placeholder'),
message:
form_data.reranker_model_id_type === 'reference'
? $t('workflow.variable.placeholder')
: $t('workflow.nodes.rerankerNode.reranker_model.placeholder'),
trigger: 'change',
}"
>
<template #label>
<div class="flex-between">
<span
>{{ $t('workflow.nodes.rerankerNode.reranker_model.label')
}}<span class="color-danger">*</span></span
<div class="flex-between w-full">
<div>
<span
>{{ $t('workflow.nodes.rerankerNode.reranker_model.label')
}}<span class="color-danger">*</span></span
>
</div>
<el-select
v-model="form_data.reranker_model_id_type"
:teleported="false"
size="small"
style="width: 85px"
@change="form_data.reranker_model_id_reference = []"
>
<el-option :label="$t('workflow.variable.Referencing')" value="reference" />
<el-option :label="$t('common.custom')" value="custom" />
</el-select>
</div>
</template>
<ModelSelect
@wheel="wheel"
:teleported="false"
v-model="form_data.reranker_model_id"
:placeholder="$t('workflow.nodes.rerankerNode.reranker_model.placeholder')"
:options="modelOptions"
@submitModel="getSelectModel"
showFooter
:model-type="'RERANKER'"
></ModelSelect>
<div class="flex-between w-full" v-if="form_data.reranker_model_id_type !== 'reference'">
<ModelSelect
@wheel="wheel"
:teleported="false"
v-model="form_data.reranker_model_id"
:placeholder="$t('workflow.nodes.rerankerNode.reranker_model.placeholder')"
:options="modelOptions"
@submitModel="getSelectModel"
showFooter
:model-type="'RERANKER'"
></ModelSelect>
</div>
<NodeCascader
v-else
ref="modelCascaderRef"
:nodeModel="nodeModel"
class="w-full"
:placeholder="$t('workflow.variable.placeholder')"
v-model="form_data.reranker_model_id_reference"
/>
</el-form-item>

<el-form-item
:label="$t('workflow.nodes.searchKnowledgeNode.showKnowledge.label')"
prop="show_knowledge"
Expand Down Expand Up @@ -192,6 +222,8 @@ const ParamSettingDialogRef = ref<InstanceType<typeof ParamSettingDialog>>()
const form = {
reranker_reference_list: [[]],
reranker_model_id: '',
reranker_model_id_type: 'custom',
reranker_model_id_reference: [],
question_reference_address: [],
reranker_setting: {
top_n: 3,
Expand Down Expand Up @@ -222,6 +254,12 @@ const wheel = (e: any) => {
const form_data = computed({
get: () => {
if (props.nodeModel.properties.node_data) {
if (!props.nodeModel.properties.node_data.reranker_model_id_type) {
set(props.nodeModel.properties.node_data, 'reranker_model_id_type', 'custom')
}
if (!props.nodeModel.properties.node_data.reranker_model_id_reference) {
set(props.nodeModel.properties.node_data, 'reranker_model_id_reference', [])
}
return props.nodeModel.properties.node_data
} else {
set(props.nodeModel.properties, 'node_data', form)
Expand All @@ -232,10 +270,13 @@ const form_data = computed({
set(props.nodeModel.properties, 'node_data', value)
},
})

function refreshParam(data: any) {
set(props.nodeModel.properties.node_data, 'reranker_setting', data)
}

const modelCascaderRef = ref()

const resource = getResourceDetail()
function getSelectModel() {
const obj =
Expand Down Expand Up @@ -264,6 +305,7 @@ const nodeCascaderRef = ref()
const validate = () => {
return Promise.all([
nodeCascaderRef.value ? nodeCascaderRef.value.validate() : Promise.resolve(''),
modelCascaderRef.value ? modelCascaderRef.value.validate() : Promise.resolve(''),
rerankerNodeFormRef.value?.validate(),
]).catch((err: any) => {
return Promise.reject({ node: props.nodeModel, errMessage: err })
Expand Down
Loading