diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 53dfe7a0cd4..1dc23e58281 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -187,7 +187,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record self.context['model_setting'] = model_setting workspace_id = self.workflow_manage.get_body().get('workspace_id') chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, - **model_params_setting) + **(model_params_setting or {})) history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type, self.runtime_node_id) self.context['history_message'] = [{'content': message.content, 'role': message.type} for message in diff --git a/apps/application/flow/step_node/image_to_video_step_node/i_image_to_video_node.py b/apps/application/flow/step_node/image_to_video_step_node/i_image_to_video_node.py index adc6731ee0e..f3aee5ef5c1 100644 --- a/apps/application/flow/step_node/image_to_video_step_node/i_image_to_video_node.py +++ b/apps/application/flow/step_node/image_to_video_step_node/i_image_to_video_node.py @@ -10,7 +10,10 @@ class ImageToVideoNodeSerializer(serializers.Serializer): - model_id = serializers.CharField(required=True, label=_("Model id")) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id")) + model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type")) + model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True, + label=_("Reference Field")) prompt = serializers.CharField(required=True, label=_("Prompt word (positive)")) @@ -69,5 +72,6 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t model_params_setting, chat_record_id, first_frame_url, last_frame_url, + model_id_type=None, model_id_reference=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py b/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py index 2bc6a2a02ca..895068c9bb2 100644 --- a/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py +++ b/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py @@ -29,10 +29,21 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t model_params_setting, chat_record_id, first_frame_url, last_frame_url=None, + model_id_type=None, model_id_reference=None, **kwargs) -> NodeResult: + # 处理引用类型 + if model_id_type == 'reference' and model_id_reference: + reference_data = self.workflow_manage.get_reference_field( + model_id_reference[0], + model_id_reference[1:], + ) + if reference_data and isinstance(reference_data, dict): + model_id = reference_data.get('model_id', model_id) + model_params_setting = reference_data.get('model_params_setting') + workspace_id = self.workflow_manage.get_body().get('workspace_id') ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, - **model_params_setting) + **(model_params_setting or {})) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/flow/step_node/intent_node/impl/base_intent_node.py b/apps/application/flow/step_node/intent_node/impl/base_intent_node.py index d458ac8268f..cdf8a769abf 100644 --- a/apps/application/flow/step_node/intent_node/impl/base_intent_node.py +++ b/apps/application/flow/step_node/intent_node/impl/base_intent_node.py @@ -71,7 +71,7 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br # 获取模型实例 workspace_id = self.workflow_manage.get_body().get('workspace_id') chat_model = get_model_instance_by_model_workspace_id( - model_id, workspace_id, **model_params_setting + model_id, workspace_id, **(model_params_setting or {}) ) # 获取历史对话 diff --git a/apps/application/flow/step_node/parameter_extraction_node/i_parameter_extraction_node.py b/apps/application/flow/step_node/parameter_extraction_node/i_parameter_extraction_node.py index f36cd4c26ce..54c60bb096c 100644 --- a/apps/application/flow/step_node/parameter_extraction_node/i_parameter_extraction_node.py +++ b/apps/application/flow/step_node/parameter_extraction_node/i_parameter_extraction_node.py @@ -19,7 +19,10 @@ class VariableSplittingNodeParamsSerializer(serializers.Serializer): model_params_setting = serializers.DictField(required=False, label=_("Model parameter settings")) - model_id = serializers.CharField(required=True, label=_("Model id")) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id")) + model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type")) + model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True, + label=_("Reference Field")) class IParameterExtractionNode(INode): @@ -31,12 +34,25 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: return VariableSplittingNodeParamsSerializer def _run(self): + model_id_type = self.node_params_serializer.data.get('model_id_type') + model_id_reference = self.node_params_serializer.data.get('model_id_reference') + model_id = self.node_params_serializer.data.get('model_id') + model_params_setting = self.node_params_serializer.data.get('model_params_setting') + # 处理引用类型 + if model_id_type == 'reference' and model_id_reference: + reference_data = self.workflow_manage.get_reference_field( + model_id_reference[0], + model_id_reference[1:], + ) + if reference_data and isinstance(reference_data, dict): + model_id = reference_data.get('model_id', model_id) + model_params_setting = reference_data.get('model_params_setting') + input_variable = self.workflow_manage.get_reference_field( self.node_params_serializer.data.get('input_variable')[0], self.node_params_serializer.data.get('input_variable')[1:]) return self.execute(input_variable, self.node_params_serializer.data['variable_list'], - self.node_params_serializer.data['model_params_setting'], - self.node_params_serializer.data['model_id']) + model_params_setting, model_id) def execute(self, input_variable, variable_list, model_params_setting, model_id, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/parameter_extraction_node/impl/base_parameter_extraction_node.py b/apps/application/flow/step_node/parameter_extraction_node/impl/base_parameter_extraction_node.py index 0d64c6b1d11..c23e0512e8b 100644 --- a/apps/application/flow/step_node/parameter_extraction_node/impl/base_parameter_extraction_node.py +++ b/apps/application/flow/step_node/parameter_extraction_node/impl/base_parameter_extraction_node.py @@ -94,11 +94,12 @@ def save_context(self, details, workflow_manage): def execute(self, input_variable, variable_list, model_params_setting, model_id, **kwargs) -> NodeResult: input_variable = str(input_variable) self.context['request'] = input_variable - if model_params_setting is None: + if model_params_setting is None and model_id: model_params_setting = get_default_model_params_setting(model_id) workspace_id = self.workflow_manage.get_body().get('workspace_id') chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, - **model_params_setting) + **(model_params_setting or {})) + content = generate_content(input_variable, variable_list) response = chat_model.invoke([HumanMessage(content=content)]) result = json_loads(response.content, variable_list) diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py index 63d5454a4bd..c083df91e7f 100644 --- a/apps/application/flow/step_node/question_node/i_question_node.py +++ b/apps/application/flow/step_node/question_node/i_question_node.py @@ -16,7 +16,10 @@ class QuestionNodeSerializer(serializers.Serializer): - model_id = serializers.CharField(required=True, label=_("Model id")) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id")) + model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type")) + model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True, + label=_("Reference Field")) system = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Role Setting")) prompt = serializers.CharField(required=True, label=_("Prompt word")) @@ -42,6 +45,6 @@ def _run(self): return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, - model_params_setting=None, + model_params_setting=None, model_id_type=None, model_id_reference=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index d2ccad823f0..2a83493a534 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -83,13 +83,23 @@ def save_context(self, details, workflow_manage): self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, - model_params_setting=None, + model_params_setting=None, model_id_type=None, model_id_reference=None, **kwargs) -> NodeResult: - if model_params_setting is None: + # 处理引用类型 + if model_id_type == 'reference' and model_id_reference: + reference_data = self.workflow_manage.get_reference_field( + model_id_reference[0], + model_id_reference[1:], + ) + if reference_data and isinstance(reference_data, dict): + model_id = reference_data.get('model_id', model_id) + model_params_setting = reference_data.get('model_params_setting') + + if model_params_setting is None and model_id: model_params_setting = get_default_model_params_setting(model_id) workspace_id = self.workflow_manage.get_body().get('workspace_id') chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, - **model_params_setting) + **(model_params_setting or {})) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py index d593e3b1f48..cd68e315d4b 100644 --- a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py +++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py @@ -69,7 +69,7 @@ def execute(self, tts_model_id, self.context['content'] = chunk workspace_id = self.workflow_manage.get_body().get('workspace_id') model = get_model_instance_by_model_workspace_id( - tts_model_id, workspace_id, **model_params_setting) + tts_model_id, workspace_id, **(model_params_setting or {})) audio_byte = model.text_to_speech(chunk) diff --git a/apps/application/flow/step_node/video_understand_step_node/i_video_understand_node.py b/apps/application/flow/step_node/video_understand_step_node/i_video_understand_node.py index 25d2971959c..fb4c7b5fa55 100644 --- a/apps/application/flow/step_node/video_understand_step_node/i_video_understand_node.py +++ b/apps/application/flow/step_node/video_understand_step_node/i_video_understand_node.py @@ -10,7 +10,10 @@ class VideoUnderstandNodeSerializer(serializers.Serializer): - model_id = serializers.CharField(required=True, label=_("Model id")) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id")) + model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type")) + model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True, + label=_("Reference Field")) system = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Role Setting")) prompt = serializers.CharField(required=True, label=_("Prompt word")) @@ -52,5 +55,6 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist model_params_setting, chat_record_id, video, + model_id_type=None, model_id_reference=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py b/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py index 3d75d980be1..3952ca9175d 100644 --- a/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py +++ b/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py @@ -75,10 +75,21 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist model_params_setting, chat_record_id, video, + model_id_type=None, model_id_reference=None, **kwargs) -> NodeResult: + # 处理引用类型 + if model_id_type == 'reference' and model_id_reference: + reference_data = self.workflow_manage.get_reference_field( + model_id_reference[0], + model_id_reference[1:], + ) + if reference_data and isinstance(reference_data, dict): + model_id = reference_data.get('model_id', model_id) + model_params_setting = reference_data.get('model_params_setting') + workspace_id = self.workflow_manage.get_body().get('workspace_id') video_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, - **model_params_setting) + **(model_params_setting or {})) # 执行详情中的历史消息不需要图片内容 history_message = self.get_history_message_for_details(history_chat_record, dialogue_number) self.context['history_message'] = history_message diff --git a/ui/src/workflow/nodes/image-to-video/index.vue b/ui/src/workflow/nodes/image-to-video/index.vue index 1c9fd57d2f5..cbb51a7313d 100644 --- a/ui/src/workflow/nodes/image-to-video/index.vue +++ b/ui/src/workflow/nodes/image-to-video/index.vue @@ -13,10 +13,13 @@ > @@ -28,31 +31,51 @@ }}* + + + + + + + + + - + + + - - - + + - { const props = defineProps<{ nodeModel: any }>() const modelOptions = ref(null) const AIModeParamSettingDialogRef = ref>() +const nodeCascaderRef = ref() const aiChatNodeFormRef = ref() const validate = () => { - return aiChatNodeFormRef.value?.validate().catch((err) => { + return Promise.all([ + nodeCascaderRef.value ? nodeCascaderRef.value.validate() : Promise.resolve(''), + aiChatNodeFormRef.value?.validate(), + ]).catch((err: any) => { return Promise.reject({ node: props.nodeModel, errMessage: err }) }) } @@ -246,6 +274,8 @@ const defaultPrompt = `{{${t('workflow.nodes.startNode.label')}.question}}` const form = { model_id: '', + model_id_type: 'custom', + model_id_reference: [], system: '', prompt: defaultPrompt, negative_prompt: '', @@ -261,6 +291,12 @@ const form = { const form_data = computed({ get: () => { if (props.nodeModel.properties.node_data) { + if (!props.nodeModel.properties.node_data.model_id_type) { + set(props.nodeModel.properties.node_data, 'model_id_type', 'custom') + } + if (!props.nodeModel.properties.node_data.model_id_reference) { + set(props.nodeModel.properties.node_data, 'model_id_reference', []) + } return props.nodeModel.properties.node_data } else { set(props.nodeModel.properties, 'node_data', form) diff --git a/ui/src/workflow/nodes/parameter-extraction-node/index.vue b/ui/src/workflow/nodes/parameter-extraction-node/index.vue index 0e7369be1ab..0a54f890f52 100644 --- a/ui/src/workflow/nodes/parameter-extraction-node/index.vue +++ b/ui/src/workflow/nodes/parameter-extraction-node/index.vue @@ -13,10 +13,13 @@ > @@ -28,29 +31,50 @@ }}* - + + + + + + + + + - + + + - - + + { if (props.nodeModel.properties.node_data) { + if (!props.nodeModel.properties.node_data.model_id_type) { + set(props.nodeModel.properties.node_data, 'model_id_type', 'custom') + } + if (!props.nodeModel.properties.node_data.model_id_reference) { + set(props.nodeModel.properties.node_data, 'model_id_reference', []) + } return props.nodeModel.properties.node_data } else { set(props.nodeModel.properties, 'node_data', form) @@ -187,8 +219,13 @@ const model_change = (model_id?: string) => { } const VariableSplittingRef = ref() +const nodeCascaderRef = ref() + const validate = async () => { - return VariableSplittingRef.value.validate().catch((err: any) => { + return Promise.all([ + nodeCascaderRef.value ? nodeCascaderRef.value.validate() : Promise.resolve(''), + VariableSplittingRef.value.validate(), + ]).catch((err: any) => { return Promise.reject({ node: props.nodeModel, errMessage: err }) }) } diff --git a/ui/src/workflow/nodes/question-node/index.vue b/ui/src/workflow/nodes/question-node/index.vue index c140ba0c0fe..e2f8fb18492 100644 --- a/ui/src/workflow/nodes/question-node/index.vue +++ b/ui/src/workflow/nodes/question-node/index.vue @@ -13,43 +13,68 @@ > - + {{ $t('views.application.form.aiModel.label') }}* + + + + + + + + + - + + + - - + + @@ -152,7 +177,7 @@