Skip to content

Commit 460ac0c

Browse files
committed
feat: Add many nodes support reference model
1 parent 7810cd2 commit 460ac0c

File tree

15 files changed

+295
-93
lines changed

15 files changed

+295
-93
lines changed

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
187187
self.context['model_setting'] = model_setting
188188
workspace_id = self.workflow_manage.get_body().get('workspace_id')
189189
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
190-
**model_params_setting)
190+
**(model_params_setting or {}))
191191
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
192192
self.runtime_node_id)
193193
self.context['history_message'] = [{'content': message.content, 'role': message.type} for message in

apps/application/flow/step_node/image_to_video_step_node/i_image_to_video_node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010

1111

1212
class ImageToVideoNodeSerializer(serializers.Serializer):
13-
model_id = serializers.CharField(required=True, label=_("Model id"))
13+
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
14+
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
15+
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
16+
label=_("Reference Field"))
1417

1518
prompt = serializers.CharField(required=True, label=_("Prompt word (positive)"))
1619

@@ -69,5 +72,6 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
6972
model_params_setting,
7073
chat_record_id,
7174
first_frame_url, last_frame_url,
75+
model_id_type=None, model_id_reference=None,
7276
**kwargs) -> NodeResult:
7377
pass

apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,21 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
2929
model_params_setting,
3030
chat_record_id,
3131
first_frame_url, last_frame_url=None,
32+
model_id_type=None, model_id_reference=None,
3233
**kwargs) -> NodeResult:
34+
# 处理引用类型
35+
if model_id_type == 'reference' and model_id_reference:
36+
reference_data = self.workflow_manage.get_reference_field(
37+
model_id_reference[0],
38+
model_id_reference[1:],
39+
)
40+
if reference_data and isinstance(reference_data, dict):
41+
model_id = reference_data.get('model_id', model_id)
42+
model_params_setting = reference_data.get('model_params_setting')
43+
3344
workspace_id = self.workflow_manage.get_body().get('workspace_id')
3445
ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
35-
**model_params_setting)
46+
**(model_params_setting or {}))
3647
history_message = self.get_history_message(history_chat_record, dialogue_number)
3748
self.context['history_message'] = history_message
3849
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/intent_node/impl/base_intent_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br
7171
# 获取模型实例
7272
workspace_id = self.workflow_manage.get_body().get('workspace_id')
7373
chat_model = get_model_instance_by_model_workspace_id(
74-
model_id, workspace_id, **model_params_setting
74+
model_id, workspace_id, **(model_params_setting or {})
7575
)
7676

7777
# 获取历史对话

apps/application/flow/step_node/parameter_extraction_node/i_parameter_extraction_node.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ class VariableSplittingNodeParamsSerializer(serializers.Serializer):
1919
model_params_setting = serializers.DictField(required=False,
2020
label=_("Model parameter settings"))
2121

22-
model_id = serializers.CharField(required=True, label=_("Model id"))
22+
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
23+
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
24+
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
25+
label=_("Reference Field"))
2326

2427

2528
class IParameterExtractionNode(INode):
@@ -31,12 +34,25 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
3134
return VariableSplittingNodeParamsSerializer
3235

3336
def _run(self):
37+
model_id_type = self.node_params_serializer.data.get('model_id_type')
38+
model_id_reference = self.node_params_serializer.data.get('model_id_reference')
39+
model_id = self.node_params_serializer.data.get('model_id')
40+
model_params_setting = self.node_params_serializer.data.get('model_params_setting')
41+
# 处理引用类型
42+
if model_id_type == 'reference' and model_id_reference:
43+
reference_data = self.workflow_manage.get_reference_field(
44+
model_id_reference[0],
45+
model_id_reference[1:],
46+
)
47+
if reference_data and isinstance(reference_data, dict):
48+
model_id = reference_data.get('model_id', model_id)
49+
model_params_setting = reference_data.get('model_params_setting')
50+
3451
input_variable = self.workflow_manage.get_reference_field(
3552
self.node_params_serializer.data.get('input_variable')[0],
3653
self.node_params_serializer.data.get('input_variable')[1:])
3754
return self.execute(input_variable, self.node_params_serializer.data['variable_list'],
38-
self.node_params_serializer.data['model_params_setting'],
39-
self.node_params_serializer.data['model_id'])
55+
model_params_setting, model_id)
4056

4157
def execute(self, input_variable, variable_list, model_params_setting, model_id, **kwargs) -> NodeResult:
4258
pass

apps/application/flow/step_node/parameter_extraction_node/impl/base_parameter_extraction_node.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ def save_context(self, details, workflow_manage):
9494
def execute(self, input_variable, variable_list, model_params_setting, model_id, **kwargs) -> NodeResult:
9595
input_variable = str(input_variable)
9696
self.context['request'] = input_variable
97-
if model_params_setting is None:
97+
if model_params_setting is None and model_id:
9898
model_params_setting = get_default_model_params_setting(model_id)
9999
workspace_id = self.workflow_manage.get_body().get('workspace_id')
100100
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
101-
**model_params_setting)
101+
**(model_params_setting or {}))
102+
102103
content = generate_content(input_variable, variable_list)
103104
response = chat_model.invoke([HumanMessage(content=content)])
104105
result = json_loads(response.content, variable_list)

apps/application/flow/step_node/question_node/i_question_node.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616

1717

1818
class QuestionNodeSerializer(serializers.Serializer):
19-
model_id = serializers.CharField(required=True, label=_("Model id"))
19+
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
20+
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
21+
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
22+
label=_("Reference Field"))
2023
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
2124
label=_("Role Setting"))
2225
prompt = serializers.CharField(required=True, label=_("Prompt word"))
@@ -42,6 +45,6 @@ def _run(self):
4245
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
4346

4447
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
45-
model_params_setting=None,
48+
model_params_setting=None, model_id_type=None, model_id_reference=None,
4649
**kwargs) -> NodeResult:
4750
pass

apps/application/flow/step_node/question_node/impl/base_question_node.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,23 @@ def save_context(self, details, workflow_manage):
8383
self.answer_text = details.get('answer')
8484

8585
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
86-
model_params_setting=None,
86+
model_params_setting=None, model_id_type=None, model_id_reference=None,
8787
**kwargs) -> NodeResult:
88-
if model_params_setting is None:
88+
# 处理引用类型
89+
if model_id_type == 'reference' and model_id_reference:
90+
reference_data = self.workflow_manage.get_reference_field(
91+
model_id_reference[0],
92+
model_id_reference[1:],
93+
)
94+
if reference_data and isinstance(reference_data, dict):
95+
model_id = reference_data.get('model_id', model_id)
96+
model_params_setting = reference_data.get('model_params_setting')
97+
98+
if model_params_setting is None and model_id:
8999
model_params_setting = get_default_model_params_setting(model_id)
90100
workspace_id = self.workflow_manage.get_body().get('workspace_id')
91101
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
92-
**model_params_setting)
102+
**(model_params_setting or {}))
93103
history_message = self.get_history_message(history_chat_record, dialogue_number)
94104
self.context['history_message'] = history_message
95105
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def execute(self, tts_model_id,
6969
self.context['content'] = chunk
7070
workspace_id = self.workflow_manage.get_body().get('workspace_id')
7171
model = get_model_instance_by_model_workspace_id(
72-
tts_model_id, workspace_id, **model_params_setting)
72+
tts_model_id, workspace_id, **(model_params_setting or {}))
7373

7474
audio_byte = model.text_to_speech(chunk)
7575

apps/application/flow/step_node/video_understand_step_node/i_video_understand_node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010

1111

1212
class VideoUnderstandNodeSerializer(serializers.Serializer):
13-
model_id = serializers.CharField(required=True, label=_("Model id"))
13+
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
14+
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
15+
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
16+
label=_("Reference Field"))
1417
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
1518
label=_("Role Setting"))
1619
prompt = serializers.CharField(required=True, label=_("Prompt word"))
@@ -52,5 +55,6 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
5255
model_params_setting,
5356
chat_record_id,
5457
video,
58+
model_id_type=None, model_id_reference=None,
5559
**kwargs) -> NodeResult:
5660
pass

0 commit comments

Comments
 (0)