Skip to content

Commit d5874e7

Browse files
意图识别:支持自定义。(Intent recognition: support custom prompt templates.)
1 parent ab9c93a commit d5874e7

36 files changed

Lines changed: 249 additions & 83 deletions

File tree

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,14 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
167167

168168
if model_params_setting is None:
169169
model_params_setting = get_default_model_params_setting(model_id)
170+
170171
if model_setting is None:
171172
model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
172173
'reasoning_content_start': '<think>'}
173174
self.context['model_setting'] = model_setting
174175
workspace_id = self.workflow_manage.get_body().get('workspace_id')
175176
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
176-
**model_params_setting)
177+
**(model_params_setting or {}))
177178
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
178179
self.runtime_node_id)
179180
self.context['history_message'] = [{'content': message.content, 'role': message.type} for message in

apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
2828
**kwargs) -> NodeResult:
2929
workspace_id = self.workflow_manage.get_body().get('workspace_id')
3030
tti_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
31-
**model_params_setting)
31+
**(model_params_setting or {}))
3232
history_message = self.get_history_message(history_chat_record, dialogue_number)
3333
self.context['history_message'] = history_message
3434
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
3232
**kwargs) -> NodeResult:
3333
workspace_id = self.workflow_manage.get_body().get('workspace_id')
3434
ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
35-
**model_params_setting)
35+
**(model_params_setting or {}))
3636
history_message = self.get_history_message(history_chat_record, dialogue_number)
3737
self.context['history_message'] = history_message
3838
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _run(self):
4343
if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
4444
self.workflow_manage.flow.workflow_mode):
4545
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data,
46-
**{'history_chat_record': [], 'stream': True, 'chat_record_id': None})
46+
**{'history_chat_record': [], 'stream': True, 'chat_record_id': None})
4747
else:
4848
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
4949

apps/application/flow/step_node/intent_node/impl/base_intent_node.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from application.flow.step_node.intent_node.i_intent_node import IIntentNode
1313
from models_provider.models import Model
1414
from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential
15-
from .prompt_template import PROMPT_TEMPLATE
15+
from .prompt_template import DEFAULT_PROMPT_TEMPLATE
1616

1717

1818
def get_default_model_params_setting(model_id):
@@ -52,7 +52,7 @@ def save_context(self, details, workflow_manage):
5252
self.context['branch_id'] = details.get('branch_id')
5353
self.context['category'] = details.get('category')
5454

55-
def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch, output_reason,
55+
def execute(self, model_id, prompt_template, dialogue_number, history_chat_record, user_input, branch, output_reason,
5656
model_params_setting=None, **kwargs) -> NodeResult:
5757

5858
# 设置默认模型参数
@@ -62,7 +62,7 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br
6262
# 获取模型实例
6363
workspace_id = self.workflow_manage.get_body().get('workspace_id')
6464
chat_model = get_model_instance_by_model_workspace_id(
65-
model_id, workspace_id, **model_params_setting
65+
model_id, workspace_id, **(model_params_setting or {})
6666
)
6767

6868
# 获取历史对话
@@ -73,7 +73,8 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br
7373
self.context['user_input'] = user_input
7474

7575
# 构建分类提示词
76-
prompt = self.build_classification_prompt(user_input, branch, output_reason)
76+
prompt_template = self.workflow_manage.generate_prompt(prompt_template) if prompt_template else None
77+
prompt = self.build_classification_prompt(prompt_template, user_input, branch, output_reason)
7778
self.context['system'] = prompt
7879

7980
# 生成消息列表
@@ -95,7 +96,7 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br
9596
'history_message': history_message,
9697
'user_input': user_input,
9798
'branch_id': matched_branch['id'],
98-
'reason': self.parse_result_reason(r.content) if output_reason is True else '',
99+
'reason': self.parse_result_reason(r.content) if output_reason is not False else '',
99100
'category': matched_branch.get('content', matched_branch['id'])
100101
}, {}, _write_context=write_context)
101102

@@ -125,7 +126,7 @@ def get_history_message(history_chat_record, dialogue_number):
125126
message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content)
126127
return history_message
127128

128-
def build_classification_prompt(self, user_input: str, branch: List[Dict], output_reason: bool) -> str:
129+
def build_classification_prompt(self, prompt_template: str, user_input: str, branch: List[Dict], output_reason: bool) -> str:
129130
"""构建分类提示词"""
130131

131132
classification_list = []
@@ -148,10 +149,10 @@ def build_classification_prompt(self, user_input: str, branch: List[Dict], outpu
148149
classification_id += 1
149150

150151
# 构建输出JSON结构
151-
output_reason = ',\n"reason": ""' if output_reason is True else ''
152-
output_json = f'{{\n"classificationId": 0{output_reason}\n}}'
152+
reason_field = ',\n"reason": ""' if output_reason is not False else ''
153+
output_json = f'{{\n"classificationId": 0{reason_field}\n}}'
153154

154-
return PROMPT_TEMPLATE.format(
155+
return (prompt_template or DEFAULT_PROMPT_TEMPLATE).format(
155156
classification_list=json.dumps(classification_list, ensure_ascii=False),
156157
user_input=user_input,
157158
output_json=output_json
@@ -179,8 +180,17 @@ def get_branch_by_id(category_id: int):
179180
return None
180181

181182
try:
182-
result_json = json.loads(result)
183-
classification_id = result_json.get('classificationId')
183+
classification_id = None
184+
185+
# 如果长度小于5,先尝试解析为数字(增加自由度,在自定义提示词模板时,可提示大模型只输出意图分类的ID值)
186+
if len(result) < 5:
187+
classification_id = self.to_int(result)
188+
189+
# 尝试解析为 JSON
190+
if classification_id is None:
191+
result_json = json.loads(result)
192+
classification_id = result_json.get('classificationId')
193+
184194
# 如果是 0 ,返回其他分支
185195
matched_branch = get_branch_by_id(classification_id)
186196
if matched_branch:
@@ -220,6 +230,12 @@ def parse_result_reason(self, result: str):
220230

221231
return ''
222232

233+
def to_int(self, str):
234+
try:
235+
return int(str)
236+
except ValueError:
237+
return None
238+
223239
def find_other_branch(self, branch: List[Dict]) -> Dict[str, Any] | None:
224240
"""查找其他分支"""
225241
for b in branch:

apps/application/flow/step_node/intent_node/impl/prompt_template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22

3-
PROMPT_TEMPLATE = """# Role
3+
DEFAULT_PROMPT_TEMPLATE = """# Role
44
You are an intention classification expert, good at being able to judge which classification the user's input belongs to.
55
66
## Skills

apps/application/flow/step_node/parameter_extraction_node/impl/base_parameter_extraction_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def execute(self, input_variable, variable_list, model_params_setting, model_id,
9898
model_params_setting = get_default_model_params_setting(model_id)
9999
workspace_id = self.workflow_manage.get_body().get('workspace_id')
100100
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
101-
**model_params_setting)
101+
**(model_params_setting or {}))
102+
102103
content = generate_content(input_variable, variable_list)
103104
response = chat_model.invoke([HumanMessage(content=content)])
104105
result = json_loads(response.content, variable_list)

apps/application/flow/step_node/question_node/impl/base_question_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
9090
model_params_setting = get_default_model_params_setting(model_id)
9191
workspace_id = self.workflow_manage.get_body().get('workspace_id')
9292
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
93-
**model_params_setting)
93+
**(model_params_setting or {}))
9494
history_message = self.get_history_message(history_chat_record, dialogue_number)
9595
self.context['history_message'] = history_message
9696
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def save_context(self, details, workflow_manage):
2323

2424
def execute(self, stt_model_id, audio, model_params_setting=None, **kwargs) -> NodeResult:
2525
workspace_id = self.workflow_manage.get_body().get('workspace_id')
26-
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **model_params_setting)
26+
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **(model_params_setting or {}))
2727
audio_list = audio
2828
self.context['audio_list'] = audio
2929

apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def execute(self, tts_model_id,
6060
self.context['content'] = chunk
6161
workspace_id = self.workflow_manage.get_body().get('workspace_id')
6262
model = get_model_instance_by_model_workspace_id(
63-
tts_model_id, workspace_id, **model_params_setting)
63+
tts_model_id, workspace_id, **(model_params_setting or {}))
6464

6565
audio_byte = model.text_to_speech(chunk)
6666

0 commit comments

Comments (0)