Commit e1012d3

royzhao, zhuzhongshu123, northmachine, wanxingyu.wxy, and caszkgui authored
feat(solver): support kag thinker (#640)
* feat(kag): update to v0.7 (#456)
* add think cost
* update csv scanner
* add final rerank
* add reasoner
* add iterative planner
* fix dpr search
* fix dpr search
* add reference data
* move odps import
* update requirement.txt
* update 2wiki
* add missing file
* fix markdown reader
* add iterative planning
* update version
* update runner
* update 2wiki example
* update bridge
* merge solver and solver_new
* add cur day
* writer delete
* update multi process
* add missing files
* fix report
* add chunk retrieved executor
* update try in stream runner result
* add path
* add math executor
* update hotpotqa example
* remove log
* fix python coder solver
* update hotpotqa example
* fix python coder solver
* update config
* fix bad
* add log
* remove unused code
* commit with task thought
* move kag model to common
* add default chat llm
* fix
* use static planner
* support chunk graph node
* add args
* support naive rag
* llm client support tool calls
* add default async
* add openai
* fix result
* fix markdown reader
* fix thinker
* update asyncio interface
* feat(solver): add mcp support (#444)
* upload mcp client related code
* 1. complete an end-to-end mcp client call chain, from pipeline to planner and executor 2. allow multiple mcp_server entries in the json, invoked and selected by the LLM 3. get baidu_map_mcp working
* 1. schema
* bugfix: remove redundant code
---------
Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>
* fix affairqa after solver refactor
* fix affairqa after solver refactor
* fix readme
* add params
* update version
* update mcp executor
* update mcp executor
* solver add mcp executor
* add missing file
* add mpc executor
* add executor
* x
* update
* fix requirement
* fix main llm config
* fix solver
* bugfix: fix invoke function call logic
* chg eva
* update example
* add kag layer
* add step task
* support dot refresh
* support dot refresh
* support dot refresh
* support dot refresh
* add retrieved num
* add retrieved num
* add pipelineconf
* update ppr
* update musique prompts
* update
* add to_dict for BuilderComponentData
* async build
* add deduce prompt
* add deduce prompt
* add deduce prompt
* fix reader
* add deduce prompt
* add page thinker report
* modify prmpt
* add step status
* add self cognition
* add self cognition
* add memory graph storage
* add now time
* update memory config
* add now time
* chg graph loader
* add prqa dataset and code
* bugfix: fix prqa call logic
* optimize: improve code logic, normalize generated answers
* add retry py code
* update memory graph
* update memory graph
* fix
* fix ner
* add with_out_refer generator prompt
* fix
* close ckpt
* fix query
* fix query
* update version
* add llm checker
* add llm checker
* 1. upload evalutor.py and change the gold_answer.json format 2. improve code logic 3. update README.md
* update exp
* update exp
* rerank support
* add static rewrite query
* recall more chunks
* fix graph load
* add static rewrite query
* fix bugs
* add finish check
* add finish check
* add finish check
* add finish check
* 1. upload evalutor.py results 2. improve code logic and the readme
* add lf retry
* add memory graph api
* fix reader api
* add ner
* add metrics
* fix bug
* remove ner
* add reraise fo retry
* add edge prop to memory graph
* add memory graph
* 1. correct evaluation dataset results 2. improve evaluator.py 3. remove questions that have an answer in gold_answer but no result
* delete evaluation result files
* fix knext host addr
* async eva
* add lf prompt
* add lf prompt
* add config
* add retry
* add unknown check
* add rc result
* add rc result
* add rc result
* add rc result
* adapt code logic to the kag pipeline format and pass tests
* bugfix: remove redundant code
* fix report prompt
* bugfix: trigger the retry mechanism
* bugfix: Chinese punctuation error
* fix rethinker prompt
* update version to 0.6.2b78
* update version
* 1. update evaluator.py to compute accuracy with an LLM, matching the latest call logic 2. update the prompt so unanswered results are retested
* update affairqa for evaluate
* update affairqa for evaluate
* bugfix: correct the dataset
* bugfix: correct the dataset
* bugfix: correct the dataset
* fix name conflict
* bugfix: remove erroneous questions
* bugfix: wrong file name caused evaluator failure
* update for affairqa eval
* bugfix: update code to keep evaluate logic consistent
* x
* update for affairqa readme
* remove temp eval scripts
* bugfix for math deduce
* merge 0.6.2_dev
* merge 0.6.2_dev
* fix
* update client addr
* updated version
* update for affairqa eval
* evaUtils supports Chinese
* fix affairqa eval:
* remove unused example
* update kag config
* fix default value
* update readme
* fix init
* update comments and add descriptions for some classes
* update example config
* Tc 0.7.0 (#459)
* submit affairQA code
* fix affairqa eval
---------
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
* fix all examples
* reformat
---------
Co-authored-by: peilong <peilong.zpl@antgroup.com>
Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com>
Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
* update chunk metadata
* update chunk metadata
* add debug reporter
* update table text
* add server
* fix math executor
* update api-key for openai vec
* update
* fix naive rag bug
* format code
* fix
---------
Co-authored-by: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com>
Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com>
Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
1 parent 9b2d894 commit e1012d3

20 files changed

Lines changed: 696 additions & 175 deletions

KAG_VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-0.8.0
+0.8.0

kag/common/utils.py

Lines changed: 11 additions & 3 deletions
@@ -463,9 +463,17 @@ def resolve_instance(


 def extract_tag_content(text):
-    # Match the content between <tag> and </tag>; any tag name is supported
-    matches = re.findall(r"<([^>]+)>(.*?)</\1>", text, flags=re.DOTALL)
-    return [(tag, content.strip()) for tag, content in matches]
+    pattern = r"<(\w+)\b[^>]*>(.*?)</\1>|<(\w+)\b[^>]*>([^<]*)|([^<]+)"
+    results = []
+    for match in re.finditer(pattern, text, re.DOTALL):
+        tag1, content1, tag2, content2, raw_text = match.groups()
+        if tag1:
+            results.append((tag1, content1))  # keep the original content (including whitespace)
+        elif tag2:
+            results.append((tag2, content2))  # keep the original content (including whitespace)
+        elif raw_text:
+            results.append(("", raw_text))  # keep the original whitespace
+    return results


 def extract_specific_tag_content(text, tag):
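To make the behavioral change in this hunk concrete, the sketch below copies the new `extract_tag_content` into a standalone script and runs it on two inputs. The sample strings are invented for illustration and are not taken from the KAG test suite.

```python
import re


def extract_tag_content(text):
    # Same pattern as the new version in this commit: closed tags first,
    # then unclosed tags, then plain text outside any tag.
    pattern = r"<(\w+)\b[^>]*>(.*?)</\1>|<(\w+)\b[^>]*>([^<]*)|([^<]+)"
    results = []
    for match in re.finditer(pattern, text, re.DOTALL):
        tag1, content1, tag2, content2, raw_text = match.groups()
        if tag1:
            results.append((tag1, content1))
        elif tag2:
            results.append((tag2, content2))
        elif raw_text:
            results.append(("", raw_text))
    return results


# Hypothetical inputs, chosen only to exercise each branch.
closed = extract_tag_content("intro <think>step one</think> tail")
unclosed = extract_tag_content("<answer>42")
print(closed)
print(unclosed)
```

Unlike the removed `re.findall` version, text outside tags is returned with an empty tag name, an unclosed tag still yields its content, and nothing is stripped.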

kag/solver/executor/math/py_based_math_executor.py

Lines changed: 5 additions & 3 deletions
@@ -131,17 +131,19 @@ def invoke(self, query: str, task: Task, context: Context, **kwargs):
         )

         parent_results = format_task_dep_context(task.parents)
-        parent_results = "\n".join(parent_results)
+        coder_content = context.kwargs.get("planner_thought", "") + "\n\n".join(
+            parent_results
+        )

-        parent_results += "\n\n" + contents
+        coder_content += "\n\n" + contents
         tries = self.tries
         error = None

         while tries > 0:
             tries -= 1
             rst, error, code = self.run_once(
                 math_query,
-                parent_results,
+                coder_content,
                 error,
                 segment_name=tag_id,
                 tag_name=f"{task_query}_code_generator",
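As a rough illustration of what the code generator now receives, the sketch below reproduces the string composition from this hunk in isolation. The helper name `build_coder_content` and the sample values are assumptions made for the example; only the two concatenation expressions come from the diff.

```python
def build_coder_content(planner_thought, parent_results, contents):
    # Same two expressions as the hunk: the planner thought is prefixed to
    # the parent results joined by blank lines, then the contents follow.
    coder_content = planner_thought + "\n\n".join(parent_results)
    coder_content += "\n\n" + contents
    return coder_content


# Hypothetical values for illustration only.
combined = build_coder_content("Plan: add.", ["a=1", "b=2"], "docs")
print(repr(combined))
```

As written, the expression puts no separator between the planner thought and the first parent result; a blank line appears only between parent results and before `contents`.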

kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py

Lines changed: 105 additions & 69 deletions
@@ -42,6 +42,15 @@
 logger = logging.getLogger()


+def _wrapped_invoke(retriever, task, context, segment_name, kwargs):
+    start_time = time.time()
+    output = retriever.invoke(
+        task, context=context, segment_name=segment_name, **kwargs
+    )
+    elapsed_time = time.time() - start_time
+    return output, elapsed_time
+
+
 @ExecutorABC.register("kag_hybrid_retrieval_executor")
 class KAGHybridRetrievalExecutor(ExecutorABC):
     def __init__(
@@ -76,6 +85,7 @@ def __init__(
         self.context_select_prompt = context_select_prompt or PromptABC.from_config(
             {"type": "context_select_prompt"}
         )
+        self.with_llm_select = kwargs.get("with_llm_select", True)

     @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1))
     def context_select_call(self, variables):
@@ -152,22 +162,30 @@ def do_retrieval(
                     "FINISH",
                     component_name=retriever.name,
                 )
-
+                # Record start time before submitting the task
+                start_time = time.time()
                 # Prepare function and submit to thread pool
                 func = partial(
-                    retriever.invoke,
+                    _wrapped_invoke,
+                    retriever,
                     task,
-                    context=context,
-                    segment_name=tag_id,
-                    **kwargs,
+                    context,
+                    tag_id,
+                    kwargs.copy(),
                 )
                 future = executor.submit(func)
+                # Save future, retriever, and start_time together
                 futures.append((future, retriever))

             # Collect results from each future
             for future, retriever in futures:
                 try:
-                    output = future.result()  # Wait for result
+                    output, elapsed_time = future.result()  # Wait for result
+
+                    # Log the elapsed time for this retriever
+                    logger.info(
+                        f"Retriever {retriever.name} executed in {elapsed_time:.2f} seconds"
+                    )
                     outputs.append(output)

                     # Log data report after successful execution
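The timing pattern in this hunk can be tried outside KAG. The following is a minimal sketch, assuming a stand-in `FakeRetriever` class and a `run_all` helper that are both hypothetical; only the idea of wrapping `invoke` so that each worker returns `(output, elapsed_time)` comes from the diff.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial


class FakeRetriever:
    """Hypothetical stand-in for a KAG retriever; not part of the library."""

    def __init__(self, name, delay):
        self.name = name
        self.delay = delay

    def invoke(self, task, context=None, segment_name=None, **kwargs):
        time.sleep(self.delay)
        return f"{self.name}:{task}"


def wrapped_invoke(retriever, task, context, segment_name, kwargs):
    # Timing starts inside the worker, so time spent queued is excluded.
    start_time = time.time()
    output = retriever.invoke(
        task, context=context, segment_name=segment_name, **kwargs
    )
    return output, time.time() - start_time


def run_all(retrievers, task, context=None, segment_name="demo", **kwargs):
    results = []
    with ThreadPoolExecutor() as executor:
        futures = []
        for retriever in retrievers:
            func = partial(
                wrapped_invoke, retriever, task, context, segment_name, kwargs.copy()
            )
            futures.append((executor.submit(func), retriever))
        for future, retriever in futures:
            output, elapsed = future.result()  # blocks until this worker finishes
            results.append((retriever.name, output, elapsed))
    return results


timings = run_all([FakeRetriever("vector", 0.01), FakeRetriever("graph", 0.02)], "q1")
for name, output, elapsed in timings:
    print(f"Retriever {name} executed in {elapsed:.2f} seconds -> {output}")
```

Passing `kwargs.copy()` positionally, as the hunk does, gives each worker its own dictionary, so a retriever that mutates its keyword arguments cannot affect the others.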
@@ -241,13 +259,18 @@ def do_summary(
         selected_rel = list(set(selected_rel))
         formatted_docs = [str(rel) for rel in selected_rel]
         if retrieved_data.chunks:
-            try:
-                selected_chunks = self.context_select(task_query, retrieved_data.chunks)
-            except Exception as e:
-                logger.warning(
-                    f"select context failed {e}, we use default top 10 to summary",
-                    exc_info=True,
-                )
+            if self.with_llm_select:
+                try:
+                    selected_chunks = self.context_select(
+                        task_query, retrieved_data.chunks
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f"select context failed {e}, we use default top 10 to summary",
+                        exc_info=True,
+                    )
+                    selected_chunks = retrieved_data.chunks[:10]
+            else:
                 selected_chunks = retrieved_data.chunks[:10]
             for doc in selected_chunks:
                 formatted_docs.append(f"{doc.content}")
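The control flow added to `do_summary` reduces to "use the LLM selector when enabled, otherwise (or on failure) keep the first ten chunks". A minimal sketch of that logic follows; the function `select_chunks`, its `selector` parameter, and the sample data are hypothetical and stand in for `self.context_select` and `retrieved_data.chunks`.

```python
def select_chunks(chunks, selector, with_llm_select=True, top_n=10):
    """Return selector(chunks), or the first top_n chunks when selection is off or fails."""
    if with_llm_select:
        try:
            return selector(chunks)
        except Exception:
            # Mirrors the hunk: any selection error falls back to the top-N slice.
            return chunks[:top_n]
    return chunks[:top_n]


def failing_selector(chunks):
    # Hypothetical selector that always fails, to exercise the fallback.
    raise RuntimeError("selector unavailable")


chunks = [f"chunk-{i}" for i in range(12)]
picked = select_chunks(chunks, lambda c: c[:2])
fallback = select_chunks(chunks, failing_selector)
skipped = select_chunks(chunks, failing_selector, with_llm_select=False)
print(picked, len(fallback), len(skipped))
```

With `with_llm_select=False` the selector is never called, which is how the new flag lets a pipeline skip the LLM round trip.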
@@ -280,69 +303,82 @@ def invoke(self, query, task, context: Context, **kwargs) -> RetrieverOutput:
         task_query = task.arguments["query"]

         tag_id = f"{task_query}_begin_task"
-        self.report_content(reporter, "thinker", tag_id, "", "FINISH", step=task.name)
+        self.report_content(reporter, "thinker", tag_id, "", "INIT", step=task.name)
         try:
-            retrieved_data = self.do_main(task_query, tag_id, task, context, **kwargs)
-        except Exception as e:
-            logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
-            retrieved_data = RetrieverOutput(
-                retriever_method=self.schema().get("name", ""), err_msg=str(e)
-            )
+            try:
+                retrieved_data = self.do_main(
+                    task_query, tag_id, task, context, **kwargs
+                )
+            except Exception as e:
+                logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
+                retrieved_data = RetrieverOutput(
+                    retriever_method=self.schema().get("name", ""), err_msg=str(e)
+                )

-        self.report_content(
-            reporter,
-            "reference",
-            f"{task_query}_kag_retriever_result",
-            retrieved_data,
-            "FINISH",
-        )
+            self.report_content(
+                reporter,
+                "reference",
+                f"{task_query}_kag_retriever_result",
+                retrieved_data,
+                "FINISH",
+            )

-        retrieved_data.task = task
-        logical_node = task.arguments.get("logic_form_node", None)
-        if (
-            logical_node
-            and isinstance(logical_node, GetSPONode)
-            and retrieved_data.summary
-        ):
-            if isinstance(retrieved_data.summary, str):
-                target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
-                s_entities = context.variables_graph.get_entity_by_alias(
-                    logical_node.s.alias_name
+            retrieved_data.task = task
+            logical_node = task.arguments.get("logic_form_node", None)
+            if (
+                logical_node
+                and isinstance(logical_node, GetSPONode)
+                and retrieved_data.summary
+            ):
+                if isinstance(retrieved_data.summary, str):
+                    target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
+                    s_entities = context.variables_graph.get_entity_by_alias(
+                        logical_node.s.alias_name
+                    )
+                    if (
+                        not s_entities
+                        and not logical_node.s.get_mention_name()
+                        and isinstance(logical_node.s, SPOEntity)
+                    ):
+                        logical_node.s.entity_name = target_answer
+                        context.kwargs[logical_node.s.alias_name] = logical_node.s
+                    o_entities = context.variables_graph.get_entity_by_alias(
+                        logical_node.o.alias_name
+                    )
+                    if (
+                        not o_entities
+                        and not logical_node.o.get_mention_name()
+                        and isinstance(logical_node.o, SPOEntity)
+                    ):
+                        logical_node.o.entity_name = target_answer
+                        context.kwargs[logical_node.o.alias_name] = logical_node.o
+
+                context.variables_graph.add_answered_alias(
+                    logical_node.s.alias_name.alias_name, retrieved_data.summary
                 )
-                if (
-                    not s_entities
-                    and not logical_node.s.get_mention_name()
-                    and isinstance(logical_node.s, SPOEntity)
-                ):
-                    logical_node.s.entity_name = target_answer
-                    context.kwargs[logical_node.s.alias_name] = logical_node.s
-                o_entities = context.variables_graph.get_entity_by_alias(
-                    logical_node.o.alias_name
+                context.variables_graph.add_answered_alias(
+                    logical_node.p.alias_name.alias_name, retrieved_data.summary
                 )
-                if (
-                    not o_entities
-                    and not logical_node.o.get_mention_name()
-                    and isinstance(logical_node.o, SPOEntity)
-                ):
-                    logical_node.o.entity_name = target_answer
-                    context.kwargs[logical_node.o.alias_name] = logical_node.o
-
-            context.variables_graph.add_answered_alias(
-                logical_node.s.alias_name.alias_name, retrieved_data.summary
-            )
-            context.variables_graph.add_answered_alias(
-                logical_node.p.alias_name.alias_name, retrieved_data.summary
+                context.variables_graph.add_answered_alias(
+                    logical_node.o.alias_name.alias_name, retrieved_data.summary
+                )
+
+            task.update_result(retrieved_data)
+            logger.debug(
+                f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
             )
-            context.variables_graph.add_answered_alias(
-                logical_node.o.alias_name.alias_name, retrieved_data.summary
+            return retrieved_data
+        finally:
+            self.report_content(
+                reporter,
+                "thinker",
+                tag_id,
+                "",
+                "FINISH",
+                step=task.name,
+                overwrite=False,
             )

-        task.update_result(retrieved_data)
-        logger.debug(
-            f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
-        )
-        return retrieved_data
-
     def schema(self) -> dict:
         """Function schema definition for OpenAI Function Calling

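The restructured `invoke` reports `INIT` up front and moves the `FINISH` report into a `finally` block, so the thinker segment is closed even if post-processing raises. Below is a minimal sketch of that shape; `RecordingReporter` and the simplified `invoke` signature are hypothetical and do not model `report_content` or its `overwrite` argument.

```python
class RecordingReporter:
    """Hypothetical reporter that records status events in order."""

    def __init__(self):
        self.events = []

    def report(self, tag_id, status):
        self.events.append((tag_id, status))


def invoke(reporter, tag_id, work):
    reporter.report(tag_id, "INIT")
    try:
        try:
            result = work()
        except Exception as e:
            # Inner handler converts a retrieval failure into an error result.
            result = {"err_msg": str(e)}
        return result
    finally:
        # Runs on normal return and on any error raised after the inner try.
        reporter.report(tag_id, "FINISH")


reporter = RecordingReporter()
ok = invoke(reporter, "t1", lambda: "data")
failed = invoke(reporter, "t2", lambda: 1 / 0)
print(ok, failed, reporter.events)
```

The inner `try` keeps the existing behavior of turning a retrieval error into an output carrying `err_msg`, while the outer `finally` only guarantees the closing status report.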
@@ -403,7 +439,7 @@ def do_data_report(
                     node_type=chunk.properties.get("__labels__"),
                 )
                 entity_prop = dict(chunk.properties) if chunk.properties else {}
-                entity_prop["content"] = chunk.content
+                entity_prop["content"] = f"{chunk.content[:10]}..."
                 entity_prop["score"] = chunk.score
                 entity.prop = Prop.from_dict(entity_prop, "Chunk", None)
                 chunk_graph.append(entity)

kag/solver/main_solver.py

Lines changed: 6 additions & 5 deletions
@@ -140,8 +140,6 @@ def get_pipeline_conf(use_pipeline_name, config):
             raise RuntimeError("mcpServers not found in config.")
         default_solver_pipeline["executors"] = mcp_executors

-    # update KAG_CONFIG
-    KAG_CONFIG.update_conf(default_pipeline_conf)
     return default_solver_pipeline


@@ -167,8 +165,11 @@ async def do_qa_pipeline(
                 f"Knowledge base with id {kb_project_id} not found in qa_config['kb']"
             )
             continue
-
-        for index_name in matched_kb.get("index_list", []):
+        index_list = matched_kb.get("index_list", [])
+        if use_pipeline in ["default_pipeline"]:
+            # we only use chunk index
+            index_list = ["chunk_index"]
+        for index_name in index_list:
             index_manager = KAGIndexManager.from_config(
                 {
                     "type": index_name,
@@ -339,7 +340,7 @@ class SolverMain:
     def invoke(
         self,
         project_id: int,
-        task_id: int,
+        task_id,
         query: str,
         session_id: str = "0",
         is_report=True,

kag/solver/pipelineconf/naive_rag.yaml

Lines changed: 7 additions & 10 deletions
@@ -3,20 +3,17 @@ pipeline_name: default_pipeline

 #------------kag-solver configuration start----------------#

-
-chunk_retrieved_executor: &chunk_retrieved_executor_conf
-  type: chunk_retrieved_executor
-  top_k: 10
-  retriever:
-    type: vector_chunk_retriever
-    score_threshold: 0.65
-    vectorize_model: "{vectorize_model}"
-
+kag_retriever_executor: &kag_retriever_executor_conf
+  type: kag_hybrid_retrieval_executor
+  retrievers: "{retrievers}"
+  merger:
+    type: kag_merger
+  enable_summary: false

 solver_pipeline:
   type: naive_rag_pipeline
   executors:
-    - *chunk_retrieved_executor_conf
+    - *kag_retriever_executor_conf
   generator:
     type: llm_index_generator
     llm_client: "{chat_llm}"

kag/solver/planner/kag_model_planner.py

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@ async def ainvoke(self, query, **kwargs) -> List[Task]:
                 .replace("</answer>", "")
                 .strip()
             )
+            context.kwargs["planner_thought"] = logic_form_response

             sub_queries, logic_forms = parse_logic_form_with_str(logic_form_str)
             logic_forms = self.logic_node_parser.parse_logic_form_set(
