Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
self.context['model_setting'] = model_setting
workspace_id = self.workflow_manage.get_body().get('workspace_id')
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
**(model_params_setting or {}))
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
self.runtime_node_id)
self.context['history_message'] = [{'content': message.content, 'role': message.type} for message in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@


class ImageToVideoNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, label=_("Model id"))
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
label=_("Reference Field"))

prompt = serializers.CharField(required=True, label=_("Prompt word (positive)"))

Expand Down Expand Up @@ -69,5 +72,6 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
model_params_setting,
chat_record_id,
first_frame_url, last_frame_url,
model_id_type=None, model_id_reference=None,
**kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,21 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
model_params_setting,
chat_record_id,
first_frame_url, last_frame_url=None,
model_id_type=None, model_id_reference=None,
**kwargs) -> NodeResult:
# 处理引用类型
if model_id_type == 'reference' and model_id_reference:
reference_data = self.workflow_manage.get_reference_field(
model_id_reference[0],
model_id_reference[1:],
)
if reference_data and isinstance(reference_data, dict):
model_id = reference_data.get('model_id', model_id)
model_params_setting = reference_data.get('model_params_setting')

workspace_id = self.workflow_manage.get_body().get('workspace_id')
ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
**(model_params_setting or {}))
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br
# 获取模型实例
workspace_id = self.workflow_manage.get_body().get('workspace_id')
chat_model = get_model_instance_by_model_workspace_id(
model_id, workspace_id, **model_params_setting
model_id, workspace_id, **(model_params_setting or {})
)

# 获取历史对话
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ class VariableSplittingNodeParamsSerializer(serializers.Serializer):
model_params_setting = serializers.DictField(required=False,
label=_("Model parameter settings"))

model_id = serializers.CharField(required=True, label=_("Model id"))
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
label=_("Reference Field"))


class IParameterExtractionNode(INode):
Expand All @@ -31,12 +34,25 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return VariableSplittingNodeParamsSerializer

def _run(self):
model_id_type = self.node_params_serializer.data.get('model_id_type')
model_id_reference = self.node_params_serializer.data.get('model_id_reference')
model_id = self.node_params_serializer.data.get('model_id')
model_params_setting = self.node_params_serializer.data.get('model_params_setting')
# 处理引用类型
if model_id_type == 'reference' and model_id_reference:
reference_data = self.workflow_manage.get_reference_field(
model_id_reference[0],
model_id_reference[1:],
)
if reference_data and isinstance(reference_data, dict):
model_id = reference_data.get('model_id', model_id)
model_params_setting = reference_data.get('model_params_setting')

input_variable = self.workflow_manage.get_reference_field(
self.node_params_serializer.data.get('input_variable')[0],
self.node_params_serializer.data.get('input_variable')[1:])
return self.execute(input_variable, self.node_params_serializer.data['variable_list'],
self.node_params_serializer.data['model_params_setting'],
self.node_params_serializer.data['model_id'])
model_params_setting, model_id)

def execute(self, input_variable, variable_list, model_params_setting, model_id, **kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ def save_context(self, details, workflow_manage):
def execute(self, input_variable, variable_list, model_params_setting, model_id, **kwargs) -> NodeResult:
input_variable = str(input_variable)
self.context['request'] = input_variable
if model_params_setting is None:
if model_params_setting is None and model_id:
model_params_setting = get_default_model_params_setting(model_id)
workspace_id = self.workflow_manage.get_body().get('workspace_id')
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
**(model_params_setting or {}))

content = generate_content(input_variable, variable_list)
response = chat_model.invoke([HumanMessage(content=content)])
result = json_loads(response.content, variable_list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@


class QuestionNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, label=_("Model id"))
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
label=_("Reference Field"))
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
label=_("Role Setting"))
prompt = serializers.CharField(required=True, label=_("Prompt word"))
Expand All @@ -42,6 +45,6 @@ def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None,
model_params_setting=None, model_id_type=None, model_id_reference=None,
**kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,23 @@ def save_context(self, details, workflow_manage):
self.answer_text = details.get('answer')

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None,
model_params_setting=None, model_id_type=None, model_id_reference=None,
**kwargs) -> NodeResult:
if model_params_setting is None:
# 处理引用类型
if model_id_type == 'reference' and model_id_reference:
reference_data = self.workflow_manage.get_reference_field(
model_id_reference[0],
model_id_reference[1:],
)
if reference_data and isinstance(reference_data, dict):
model_id = reference_data.get('model_id', model_id)
model_params_setting = reference_data.get('model_params_setting')

if model_params_setting is None and model_id:
model_params_setting = get_default_model_params_setting(model_id)
workspace_id = self.workflow_manage.get_body().get('workspace_id')
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
**(model_params_setting or {}))
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def execute(self, tts_model_id,
self.context['content'] = chunk
workspace_id = self.workflow_manage.get_body().get('workspace_id')
model = get_model_instance_by_model_workspace_id(
tts_model_id, workspace_id, **model_params_setting)
tts_model_id, workspace_id, **(model_params_setting or {}))

audio_byte = model.text_to_speech(chunk)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@


class VideoUnderstandNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, label=_("Model id"))
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
label=_("Reference Field"))
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
label=_("Role Setting"))
prompt = serializers.CharField(required=True, label=_("Prompt word"))
Expand Down Expand Up @@ -52,5 +55,6 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
model_params_setting,
chat_record_id,
video,
model_id_type=None, model_id_reference=None,
**kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,21 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
model_params_setting,
chat_record_id,
video,
model_id_type=None, model_id_reference=None,
**kwargs) -> NodeResult:
# 处理引用类型
if model_id_type == 'reference' and model_id_reference:
reference_data = self.workflow_manage.get_reference_field(
model_id_reference[0],
model_id_reference[1:],
)
if reference_data and isinstance(reference_data, dict):
model_id = reference_data.get('model_id', model_id)
model_params_setting = reference_data.get('model_params_setting')

workspace_id = self.workflow_manage.get_body().get('workspace_id')
video_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
**(model_params_setting or {}))
# 执行详情中的历史消息不需要图片内容
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
Expand Down
76 changes: 56 additions & 20 deletions ui/src/workflow/nodes/image-to-video/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
>
<el-form-item
:label="$t('workflow.nodes.imageToVideoGenerate.model.label')"
prop="model_id"
:prop="form_data.model_id_type === 'reference' ? 'model_id_reference' : 'model_id'"
:rules="{
required: true,
message: $t('workflow.nodes.imageToVideoGenerate.model.requiredMessage'),
message:
form_data.model_id_type === 'reference'
? $t('workflow.variable.placeholder')
: $t('workflow.nodes.imageToVideoGenerate.model.requiredMessage'),
trigger: 'change',
}"
>
Expand All @@ -28,31 +31,51 @@
}}<span class="color-danger">*</span></span
>
</div>
<el-select
v-model="form_data.model_id_type"
:teleported="false"
size="small"
style="width: 85px"
@change="form_data.model_id_reference = []"
>
<el-option :label="$t('workflow.variable.Referencing')" value="reference" />
<el-option :label="$t('common.custom')" value="custom" />
</el-select>
</div>
</template>
<div class="flex-between w-full" v-if="form_data.model_id_type !== 'reference'">
<ModelSelect
@change="model_change"
@wheel="wheel"
:teleported="false"
v-model="form_data.model_id"
@focus="getSelectModel"
:placeholder="$t('workflow.nodes.imageToVideoGenerate.model.requiredMessage')"
:options="modelOptions"
showFooter
:model-type="'ITV'"
></ModelSelect>
<div class="ml-8">
<el-button
:disabled="!form_data.model_id"
type="primary"
link
@click="openAIParamSettingDialog(form_data.model_id)"
@refreshForm="refreshParam"
>
<AppIcon iconName="app-setting"></AppIcon>
<el-icon>
<Operation />
</el-icon>
</el-button>
</div>
</template>

<ModelSelect
@change="model_change"
@wheel="wheel"
:teleported="false"
v-model="form_data.model_id"
@focus="getSelectModel"
:placeholder="$t('workflow.nodes.imageToVideoGenerate.model.requiredMessage')"
:options="modelOptions"
showFooter
:model-type="'ITV'"
></ModelSelect>
</div>
<NodeCascader
v-else
ref="nodeCascaderRef"
:nodeModel="nodeModel"
class="w-full"
:placeholder="$t('workflow.variable.placeholder')"
v-model="form_data.model_id_reference"
/>
</el-form-item>

<el-form-item
:label="$t('workflow.nodes.imageToVideoGenerate.prompt.label')"
prop="prompt"
Expand Down Expand Up @@ -203,6 +226,7 @@ import { useRoute } from 'vue-router'
import { loadSharedApi } from '@/utils/dynamics-api/shared-api'
import NodeCascader from '@/workflow/common/NodeCascader.vue'
import { WorkflowMode } from '@/enums/application'

const workflowMode = (inject('workflowMode') as WorkflowMode) || WorkflowMode.Application
const getResourceDetail = inject('getResourceDetail') as any
const route = useRoute()
Expand All @@ -224,10 +248,14 @@ const apiType = computed(() => {
const props = defineProps<{ nodeModel: any }>()
const modelOptions = ref<any>(null)
const AIModeParamSettingDialogRef = ref<InstanceType<typeof AIModeParamSettingDialog>>()
const nodeCascaderRef = ref()

const aiChatNodeFormRef = ref<FormInstance>()
const validate = () => {
return aiChatNodeFormRef.value?.validate().catch((err) => {
return Promise.all([
nodeCascaderRef.value ? nodeCascaderRef.value.validate() : Promise.resolve(''),
aiChatNodeFormRef.value?.validate(),
]).catch((err: any) => {
return Promise.reject({ node: props.nodeModel, errMessage: err })
})
}
Expand All @@ -246,6 +274,8 @@ const defaultPrompt = `{{${t('workflow.nodes.startNode.label')}.question}}`

const form = {
model_id: '',
model_id_type: 'custom',
model_id_reference: [],
system: '',
prompt: defaultPrompt,
negative_prompt: '',
Expand All @@ -261,6 +291,12 @@ const form = {
const form_data = computed({
get: () => {
if (props.nodeModel.properties.node_data) {
if (!props.nodeModel.properties.node_data.model_id_type) {
set(props.nodeModel.properties.node_data, 'model_id_type', 'custom')
}
if (!props.nodeModel.properties.node_data.model_id_reference) {
set(props.nodeModel.properties.node_data, 'model_id_reference', [])
}
return props.nodeModel.properties.node_data
} else {
set(props.nodeModel.properties, 'node_data', form)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code contains several improvements and optimizations:

  1. Refactor for Model Type Management: The model selection logic has been split into a separate template block to improve readability.

  2. Optimize Input Fields: Removed duplicate placeholder text for the model_id input field, as it matches the overall context already indicated.

  3. Add Option for Reference Models: A dropdown menu is added to allow users to reference other models instead of selecting them individually from the list.

  4. Simplify Validation Logic: Used Promise.all() to handle both nodeCascaderRef and aiChatNodeFormRef, providing clearer error handling.

  5. Remove Unused Import Statement: Removed an unnecessary import statement related to loadSharedApi.

  6. Improved Code Formatting: Minor adjustments have been made to improve formatting consistency across sections of the code.

These changes enhance the user experience, provide more flexibility in model selection, and simplify the validation process.

Expand Down
Loading
Loading