Skip to content

Commit 1435665

Browse files
committed
fix main solver pipeline
1 parent 7a6eba3 commit 1435665

5 files changed

Lines changed: 60 additions & 12 deletions

File tree

kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def do_retrieval(self, task_query, tag_id, task, context: Context, **kwargs) ->
154154
spos = list(set(spos)) # Deduplicate
155155

156156
# Add report line if there are any SPOs
157-
if spos:
157+
if reporter and spos:
158158
reporter.add_report_line(
159159
tag_id,
160160
f"end_sub_kag_retriever_{output.retriever_method}",

kag/solver/executor/retriever/kag_model_hybrid_executor.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ def do_main(self, task_query, tag_id, task, context, **kwargs):
104104

105105
break
106106
else:
107-
messages.append({
108-
"role": "assistant",
109-
"content": subquestion_response,
110-
})
111107
if "<search>" in subquestion_response:
108+
messages.append({
109+
"role": "assistant",
110+
"content": subquestion_response,
111+
})
112112
search = search_plan_extraction(subquestion_response)
113113
# 有时候会缺失</search>训练时需要优化<search>内容,不需要直接换行
114114
if len(search) == 0:
@@ -117,6 +117,9 @@ def do_main(self, task_query, tag_id, task, context, **kwargs):
117117
try:
118118
sub_queries, logic_forms = parse_logic_form_with_str(search)
119119
logic_forms = self.logic_node_parser.parse_logic_form_set(logic_forms, sub_queries, task_query)
120+
if not logic_forms:
121+
logic_node.sub_query = search
122+
logic_forms = [logic_node]
120123
except Exception as e:
121124
logger.warning(f"kag model think can not extra lf from {search} {e}")
122125
logic_node.sub_query = search
@@ -142,7 +145,7 @@ def do_main(self, task_query, tag_id, task, context, **kwargs):
142145
"query": target_query,
143146
"logic_form_node": logic_forms[0]
144147
})
145-
retriever_output = self.do_retrieval(task_query=target_query, tag_id=tag_id, task=cur_task,
148+
retriever_output = self.do_retrieval(task_query=target_query, tag_id=cur_turn_tag_name, task=cur_task,
146149
context=context, **kwargs)
147150

148151
recall_information_list = []
@@ -158,7 +161,7 @@ def do_main(self, task_query, tag_id, task, context, **kwargs):
158161
"content": recall_str,
159162
})
160163
except Exception as e:
161-
logger.error(f"kag flow exception! {e}", exc_info=True)
164+
logger.error(f"kag flow exception! {e} search={search}", exc_info=True)
162165
self.report_content(
163166
reporter,
164167
cur_turn_tag_name,
@@ -167,6 +170,11 @@ def do_main(self, task_query, tag_id, task, context, **kwargs):
167170
"INIT",
168171
step=task.name,
169172
)
173+
else:
174+
messages.append({
175+
"role": "assistant",
176+
"content": subquestion_response,
177+
})
170178

171179
context.kwargs["messages"] = messages
172180
return retriever_output

kag/solver/executor/retriever/local_knowledge_base/kag_retriever/kag_component/kg_cs/lf_kg_retriever_template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from kag.common.conf import KAGConstants, KAGConfigAccessor
66
from kag.interface import LLMClient
77
from kag.interface.solver.base_model import SPOEntity, LogicNode
8-
from kag.interface.solver.reporter_abc import ReporterABC, DotRefresher
8+
from kag.interface.solver.reporter_abc import ReporterABC
99
from kag.interface.solver.model.one_hop_graph import KgGraph, EntityData
1010
from kag.common.parser.logic_node_parser import GetSPONode
1111

kag/solver/main_solver.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,7 @@ async def qa(task_id, query, project_id, host_addr, app_id, params={}):
312312
task_id=task_id,
313313
kb_project_ids=kb_project_ids,
314314
)
315-
316-
if answer:
317-
reporter.add_report_line("answer", "Final Answer", answer, "FINISH")
315+
reporter.add_report_line("answer", "Final Answer", answer, "FINISH")
318316

319317
except Exception as e:
320318
logger.warning(
@@ -371,6 +369,40 @@ def invoke(
371369
)
372370
return answer
373371

372+
async def ainvoke(
373+
self,
374+
project_id: int,
375+
task_id: int,
376+
query: str,
377+
session_id: str = "0",
378+
is_report=True,
379+
host_addr="http://127.0.0.1:8887",
380+
params=None,
381+
app_id="",
382+
):
383+
answer = None
384+
if params is None:
385+
params = {}
386+
try:
387+
answer = await qa(
388+
task_id=task_id,
389+
project_id=project_id,
390+
host_addr=host_addr,
391+
query=query,
392+
params=params,
393+
app_id=app_id,
394+
)
395+
logger.info(f"{query} answer={answer}")
396+
except Exception as e:
397+
import traceback
398+
399+
traceback.print_exc()
400+
logger.warning(
401+
f"An exception occurred while processing query: {query}. Error: {str(e)}",
402+
exc_info=True,
403+
)
404+
return answer
405+
374406

375407
if __name__ == "__main__":
376408
# init_kag_config(

kag/solver/reporter/open_spg_kag_model_reporter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
import re
3+
24
from kag.common.conf import KAG_PROJECT_CONF
35
from kag.common.parser.logic_node_parser import extract_steps_and_actions
46

@@ -13,6 +15,12 @@ class SafeDict(dict):
1315
def __missing__(self, key):
1416
return ""
1517

18+
def remove_xml_tags(text):
19+
# 正则表达式匹配所有 XML 标签,例如:<tag> 或 </tag> 或 <tag attr="value">
20+
pattern = r'<[^>]+>'
21+
# 用空字符串替换所有匹配项
22+
clean_text = re.sub(pattern, '', text)
23+
return clean_text
1624

1725
def process_planning(think_str):
1826
result = []
@@ -67,7 +75,7 @@ def process_tag_template(text):
6775
clean_text += xml_tag_template[tag_info[0]][KAG_PROJECT_CONF.language].format_map(SafeDict({
6876
"content": content
6977
}))
70-
return clean_text
78+
return remove_xml_tags(clean_text)
7179
return text
7280

7381

0 commit comments

Comments
 (0)