Skip to content

Commit 60bf980

Browse files
committed
fix main solver pipeline
1 parent 91ee1ed commit 60bf980

5 files changed

Lines changed: 50 additions & 71 deletions

File tree

kag/common/conf.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -208,10 +208,16 @@ def update_conf(self, configs: dict):
208208
KAG_QA_TASK_CONFIG stores per-task configuration and should be cleaned up after use.
209209
"""
210210
KAG_QA_TASK_CONFIG = knext.common.cache.LinkCache(maxsize=100, ttl=300)
211-
KAG_QA_TASK_CONFIG_LOCK = threading.Lock()
212-
213211

214212
class KAGConfigAccessor:
213+
@staticmethod
214+
def get_default_config():
215+
if KAG_CONFIG.global_config.project_id:
216+
return KAG_CONFIG
217+
for k in KAG_QA_TASK_CONFIG.cache.keys():
218+
return KAG_QA_TASK_CONFIG.get(k)
219+
return KAG_CONFIG
220+
215221
@staticmethod
216222
def get_config(task_with_kb_id=None) -> KAGConfigMgr:
217223
"""
@@ -224,9 +230,8 @@ def get_config(task_with_kb_id=None) -> KAGConfigMgr:
224230
:return: Corresponding configuration object
225231
"""
226232
if task_with_kb_id is not None:
227-
with KAG_QA_TASK_CONFIG_LOCK:
228-
return KAG_QA_TASK_CONFIG.get(task_with_kb_id)
229-
return KAG_CONFIG
233+
return KAG_QA_TASK_CONFIG.get(task_with_kb_id)
234+
return KAGConfigAccessor.get_default_config()
230235

231236
@staticmethod
232237
def set_task_config(task_with_kb_id, config: KAGConfigMgr):
@@ -236,8 +241,7 @@ def set_task_config(task_with_kb_id, config: KAGConfigMgr):
236241
:param task_with_kb_id: Task ID
237242
:param config: Configuration object to store
238243
"""
239-
with KAG_QA_TASK_CONFIG_LOCK:
240-
KAG_QA_TASK_CONFIG.put(task_with_kb_id, config)
244+
KAG_QA_TASK_CONFIG.put(task_with_kb_id, config)
241245

242246

243247
def init_env(config_file: str = None):

kag/solver/executor/retriever/local_knowledge_base/chunk_retrieved_executor.py

Lines changed: 5 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import time
33
from typing import Any, Optional
44

5-
from kag.interface import ExecutorABC, ToolABC
5+
from kag.interface import ExecutorABC, RetrieverABC
66
from kag.interface.solver.reporter_abc import ReporterABC
77
from kag.solver.executor.retriever.local_knowledge_base.kag_retriever.kag_hybrid_executor import (
88
KAGRetrievedResponse,
@@ -16,7 +16,7 @@
1616

1717
@ExecutorABC.register("chunk_retrieved_executor")
1818
class ChunkRetrievedExecutor(ExecutorABC):
19-
def __init__(self, top_k, retriever: ToolABC, **kwargs):
19+
def __init__(self, top_k, retriever: RetrieverABC, **kwargs):
2020
super().__init__(**kwargs)
2121
self.retriever = retriever
2222
self.top_k = top_k
@@ -43,23 +43,11 @@ def invoke(self, query: str, task: Any, context: dict, **kwargs):
4343
"FINISH",
4444
overwrite=False,
4545
)
46-
retrieved_result = self.retriever.invoke(query=task_query, top_k=self.top_k)
46+
retrieved_result = self.retriever.invoke(task, context=context, **kwargs)
4747

4848
# Log the retrieved results
4949
logger.debug(f"Retrieved results: {retrieved_result}")
5050

51-
chunk_datas = []
52-
for k, v in retrieved_result.items():
53-
chunk_datas.append(
54-
ChunkData(
55-
content=v["content"],
56-
title=v["name"],
57-
chunk_id=k,
58-
score=v["score"],
59-
properties=v,
60-
)
61-
)
62-
kag_response.chunk_datas = chunk_datas
6351
self.report_content(
6452
reporter,
6553
"reference",
@@ -72,7 +60,7 @@ def invoke(self, query: str, task: Any, context: dict, **kwargs):
7260
reporter,
7361
f"{task_query}_begin_kag_retriever",
7462
f"{task_query}_end_kag_retriever",
75-
f"{len(chunk_datas)}",
63+
f"{len(retrieved_result.chunks)}",
7664
"FINISH",
7765
)
7866

@@ -81,8 +69,7 @@ def invoke(self, query: str, task: Any, context: dict, **kwargs):
8169
logger.info(
8270
f"Finished retrieval process for query: {task_query}. Duration: {end_time - start_time} bytes"
8371
)
84-
kag_response.summary = "retrieved by local knowledgebase"
85-
store_results(task, kag_response)
72+
task.update_result(retrieved_result)
8673

8774
def schema(self) -> dict:
8875
"""Function schema definition for OpenAI Function Calling

kag/solver/main_solver.py

Lines changed: 31 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -149,32 +149,6 @@ def is_chinese(text):
149149
chinese_pattern = re.compile(r"[\u4e00-\u9fff]+")
150150
return bool(chinese_pattern.search(text))
151151

152-
153-
async def do_index_pipeline(query, qa_config, reporter):
154-
if "chat" not in qa_config or "index_list" not in qa_config["chat"]:
155-
raise RuntimeError("chat or index_list not found in qa_config.")
156-
index_names = qa_config.get("chat", {}).get("index_list", [])
157-
retriever_configs = []
158-
for index_name in index_names:
159-
try:
160-
index_manager = KAGIndexManager.from_config(
161-
{
162-
"type": index_name,
163-
"llm_config": qa_config.get("llm", {}),
164-
"vectorize_model_config": qa_config.get("vectorize_model", {}),
165-
}
166-
)
167-
retriever_configs += index_manager.build_retriever_config(
168-
qa_config.get("llm", {}), qa_config.get("vectorize_model", {})
169-
)
170-
except Exception as e:
171-
raise RuntimeError(f"not found index {index_name}")
172-
qa_config["retrievers"] = retriever_configs
173-
pipeline_config = get_pipeline_conf("index_pipeline", qa_config)
174-
pipeline = SolverPipelineABC.from_config(pipeline_config)
175-
return await pipeline.ainvoke(query, reporter=reporter)
176-
177-
178152
async def do_qa_pipeline(
179153
use_pipeline, query, qa_config, reporter, task_id, kb_project_ids
180154
):
@@ -219,10 +193,12 @@ async def do_qa_pipeline(
219193
custom_pipeline_conf = copy.deepcopy(qa_config.get(use_pipeline, None))
220194
else:
221195
custom_pipeline_conf = copy.deepcopy(qa_config.get("solver_pipeline", None))
222-
223-
self_cognition_conf = get_pipeline_conf("self_cognition_pipeline", qa_config)
224-
self_cognition_pipeline = SolverPipelineABC.from_config(self_cognition_conf)
225-
self_cognition_res = await self_cognition_pipeline.ainvoke(query, reporter=reporter)
196+
if use_pipeline not in ["index_pipeline"]:
197+
self_cognition_conf = get_pipeline_conf("self_cognition_pipeline", qa_config)
198+
self_cognition_pipeline = SolverPipelineABC.from_config(self_cognition_conf)
199+
self_cognition_res = await self_cognition_pipeline.ainvoke(query, reporter=reporter)
200+
else:
201+
self_cognition_res = False
226202
if not self_cognition_res:
227203
if custom_pipeline_conf:
228204
pipeline_config = custom_pipeline_conf
@@ -262,7 +238,8 @@ async def qa(task_id, query, project_id, host_addr, app_id, params={}):
262238

263239
kb_configs = {}
264240
kb_project_ids = []
265-
241+
vectorize_model = {}
242+
global_index_set = main_config.get("chat", {}).get("index_list", [])
266243
if isinstance(main_config.get("kb"), list):
267244
kbs = main_config["kb"]
268245
for kb in kbs:
@@ -293,12 +270,24 @@ async def qa(task_id, query, project_id, host_addr, app_id, params={}):
293270
kb_conf.update_conf({"llm": main_config["llm"]})
294271
if "vectorizer" in kb:
295272
kb_conf.update_conf({"vectorize_model": kb["vectorizer"]})
296-
273+
vectorize_model = kb["vectorizer"]
274+
if "index_list" not in kb and global_index_set:
275+
kb["index_list"] = global_index_set
297276
KAGConfigAccessor.set_task_config(kb_task_project_id, kb_conf)
298277
kb_configs[kb_project_id] = (kb_task_project_id, kb_conf)
299-
300278
except Exception as e:
301279
logger.error(f"KB配置初始化失败: {str(e)}", exc_info=True)
280+
if "vectorize_model" not in main_config.keys():
281+
main_config["vectorize_model"] = vectorize_model
282+
283+
if vectorize_model:
284+
KAG_CONFIG.update_conf({
285+
"vectorize_model": vectorize_model
286+
})
287+
if main_config["llm"]:
288+
KAG_CONFIG.update_conf({
289+
"llm": main_config["llm"]
290+
})
302291
reporter_map = {
303292
"kag_thinker_pipeline": "kag_open_spg_reporter"
304293
}
@@ -315,17 +304,15 @@ async def qa(task_id, query, project_id, host_addr, app_id, params={}):
315304

316305
try:
317306
await reporter.start()
318-
if use_pipeline == "index_pipeline":
319-
answer = await do_index_pipeline(query, main_config, reporter)
320-
else:
321-
answer = await do_qa_pipeline(
322-
use_pipeline,
323-
query,
324-
main_config,
325-
reporter,
326-
task_id=task_id,
327-
kb_project_ids=kb_project_ids,
328-
)
307+
answer = await do_qa_pipeline(
308+
use_pipeline,
309+
query,
310+
main_config,
311+
reporter,
312+
task_id=task_id,
313+
kb_project_ids=kb_project_ids,
314+
)
315+
329316
if answer:
330317
reporter.add_report_line("answer", "Final Answer", answer, "FINISH")
331318

kag/solver/pipeline/naive_rag_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ async def planning(self, query, context, **kwargs):
7171
"""
7272
tasks_dep = {}
7373
tasks_dep[0] = {
74-
"executor": "Retriever",
74+
"executor": "ChunkRetriever",
7575
"dependent_task_ids": [],
7676
"arguments": {"query": query},
7777
}

kag/solver/pipelineconf/naive_rag.yaml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ chunk_retrieved_executor: &chunk_retrieved_executor_conf
99
top_k: 10
1010
retriever:
1111
type: vector_chunk_retriever
12+
score_threshold: 0.65
1213
vectorize_model: "{vectorize_model}"
1314

1415

@@ -17,7 +18,7 @@ solver_pipeline:
1718
executors:
1819
- *chunk_retrieved_executor_conf
1920
generator:
20-
type: llm_generator
21+
type: llm_index_generator
2122
llm_client: "{chat_llm}"
2223
generated_prompt:
2324
type: default_refer_generator_prompt

0 commit comments

Comments
 (0)