|
42 | 42 | logger = logging.getLogger() |
43 | 43 |
|
44 | 44 |
|
| 45 | +def _wrapped_invoke(retriever, task, context, segment_name, kwargs): |
| 46 | + start_time = time.time() |
| 47 | + output = retriever.invoke( |
| 48 | + task, context=context, segment_name=segment_name, **kwargs |
| 49 | + ) |
| 50 | + elapsed_time = time.time() - start_time |
| 51 | + return output, elapsed_time |
| 52 | + |
| 53 | + |
45 | 54 | @ExecutorABC.register("kag_hybrid_retrieval_executor") |
46 | 55 | class KAGHybridRetrievalExecutor(ExecutorABC): |
47 | 56 | def __init__( |
@@ -76,6 +85,7 @@ def __init__( |
76 | 85 | self.context_select_prompt = context_select_prompt or PromptABC.from_config( |
77 | 86 | {"type": "context_select_prompt"} |
78 | 87 | ) |
| 88 | + self.with_llm_select = kwargs.get("with_llm_select", True) |
79 | 89 |
|
80 | 90 | @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1)) |
81 | 91 | def context_select_call(self, variables): |
@@ -152,22 +162,30 @@ def do_retrieval( |
152 | 162 | "FINISH", |
153 | 163 | component_name=retriever.name, |
154 | 164 | ) |
155 | | - |
| 165 | + # Record start time before submitting the task |
| 166 | + start_time = time.time() |
156 | 167 | # Prepare function and submit to thread pool |
157 | 168 | func = partial( |
158 | | - retriever.invoke, |
| 169 | + _wrapped_invoke, |
| 170 | + retriever, |
159 | 171 | task, |
160 | | - context=context, |
161 | | - segment_name=tag_id, |
162 | | - **kwargs, |
| 172 | + context, |
| 173 | + tag_id, |
| 174 | + kwargs.copy(), |
163 | 175 | ) |
164 | 176 | future = executor.submit(func) |
| 177 | + # Save future, retriever, and start_time together |
165 | 178 | futures.append((future, retriever)) |
166 | 179 |
|
167 | 180 | # Collect results from each future |
168 | 181 | for future, retriever in futures: |
169 | 182 | try: |
170 | | - output = future.result() # Wait for result |
| 183 | + output, elapsed_time = future.result() # Wait for result |
| 184 | + |
| 185 | + # Log the elapsed time for this retriever |
| 186 | + logger.info( |
| 187 | + f"Retriever {retriever.name} executed in {elapsed_time:.2f} seconds" |
| 188 | + ) |
171 | 189 | outputs.append(output) |
172 | 190 |
|
173 | 191 | # Log data report after successful execution |
@@ -241,13 +259,18 @@ def do_summary( |
241 | 259 | selected_rel = list(set(selected_rel)) |
242 | 260 | formatted_docs = [str(rel) for rel in selected_rel] |
243 | 261 | if retrieved_data.chunks: |
244 | | - try: |
245 | | - selected_chunks = self.context_select(task_query, retrieved_data.chunks) |
246 | | - except Exception as e: |
247 | | - logger.warning( |
248 | | - f"select context failed {e}, we use default top 10 to summary", |
249 | | - exc_info=True, |
250 | | - ) |
| 262 | + if self.with_llm_select: |
| 263 | + try: |
| 264 | + selected_chunks = self.context_select( |
| 265 | + task_query, retrieved_data.chunks |
| 266 | + ) |
| 267 | + except Exception as e: |
| 268 | + logger.warning( |
| 269 | + f"select context failed {e}, we use default top 10 to summary", |
| 270 | + exc_info=True, |
| 271 | + ) |
| 272 | + selected_chunks = retrieved_data.chunks[:10] |
| 273 | + else: |
251 | 274 | selected_chunks = retrieved_data.chunks[:10] |
252 | 275 | for doc in selected_chunks: |
253 | 276 | formatted_docs.append(f"{doc.content}") |
@@ -280,69 +303,82 @@ def invoke(self, query, task, context: Context, **kwargs) -> RetrieverOutput: |
280 | 303 | task_query = task.arguments["query"] |
281 | 304 |
|
282 | 305 | tag_id = f"{task_query}_begin_task" |
283 | | - self.report_content(reporter, "thinker", tag_id, "", "FINISH", step=task.name) |
| 306 | + self.report_content(reporter, "thinker", tag_id, "", "INIT", step=task.name) |
284 | 307 | try: |
285 | | - retrieved_data = self.do_main(task_query, tag_id, task, context, **kwargs) |
286 | | - except Exception as e: |
287 | | - logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True) |
288 | | - retrieved_data = RetrieverOutput( |
289 | | - retriever_method=self.schema().get("name", ""), err_msg=str(e) |
290 | | - ) |
| 308 | + try: |
| 309 | + retrieved_data = self.do_main( |
| 310 | + task_query, tag_id, task, context, **kwargs |
| 311 | + ) |
| 312 | + except Exception as e: |
| 313 | + logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True) |
| 314 | + retrieved_data = RetrieverOutput( |
| 315 | + retriever_method=self.schema().get("name", ""), err_msg=str(e) |
| 316 | + ) |
291 | 317 |
|
292 | | - self.report_content( |
293 | | - reporter, |
294 | | - "reference", |
295 | | - f"{task_query}_kag_retriever_result", |
296 | | - retrieved_data, |
297 | | - "FINISH", |
298 | | - ) |
| 318 | + self.report_content( |
| 319 | + reporter, |
| 320 | + "reference", |
| 321 | + f"{task_query}_kag_retriever_result", |
| 322 | + retrieved_data, |
| 323 | + "FINISH", |
| 324 | + ) |
299 | 325 |
|
300 | | - retrieved_data.task = task |
301 | | - logical_node = task.arguments.get("logic_form_node", None) |
302 | | - if ( |
303 | | - logical_node |
304 | | - and isinstance(logical_node, GetSPONode) |
305 | | - and retrieved_data.summary |
306 | | - ): |
307 | | - if isinstance(retrieved_data.summary, str): |
308 | | - target_answer = retrieved_data.summary.split("Answer:")[-1].strip() |
309 | | - s_entities = context.variables_graph.get_entity_by_alias( |
310 | | - logical_node.s.alias_name |
| 326 | + retrieved_data.task = task |
| 327 | + logical_node = task.arguments.get("logic_form_node", None) |
| 328 | + if ( |
| 329 | + logical_node |
| 330 | + and isinstance(logical_node, GetSPONode) |
| 331 | + and retrieved_data.summary |
| 332 | + ): |
| 333 | + if isinstance(retrieved_data.summary, str): |
| 334 | + target_answer = retrieved_data.summary.split("Answer:")[-1].strip() |
| 335 | + s_entities = context.variables_graph.get_entity_by_alias( |
| 336 | + logical_node.s.alias_name |
| 337 | + ) |
| 338 | + if ( |
| 339 | + not s_entities |
| 340 | + and not logical_node.s.get_mention_name() |
| 341 | + and isinstance(logical_node.s, SPOEntity) |
| 342 | + ): |
| 343 | + logical_node.s.entity_name = target_answer |
| 344 | + context.kwargs[logical_node.s.alias_name] = logical_node.s |
| 345 | + o_entities = context.variables_graph.get_entity_by_alias( |
| 346 | + logical_node.o.alias_name |
| 347 | + ) |
| 348 | + if ( |
| 349 | + not o_entities |
| 350 | + and not logical_node.o.get_mention_name() |
| 351 | + and isinstance(logical_node.o, SPOEntity) |
| 352 | + ): |
| 353 | + logical_node.o.entity_name = target_answer |
| 354 | + context.kwargs[logical_node.o.alias_name] = logical_node.o |
| 355 | + |
| 356 | + context.variables_graph.add_answered_alias( |
| 357 | + logical_node.s.alias_name.alias_name, retrieved_data.summary |
311 | 358 | ) |
312 | | - if ( |
313 | | - not s_entities |
314 | | - and not logical_node.s.get_mention_name() |
315 | | - and isinstance(logical_node.s, SPOEntity) |
316 | | - ): |
317 | | - logical_node.s.entity_name = target_answer |
318 | | - context.kwargs[logical_node.s.alias_name] = logical_node.s |
319 | | - o_entities = context.variables_graph.get_entity_by_alias( |
320 | | - logical_node.o.alias_name |
| 359 | + context.variables_graph.add_answered_alias( |
| 360 | + logical_node.p.alias_name.alias_name, retrieved_data.summary |
321 | 361 | ) |
322 | | - if ( |
323 | | - not o_entities |
324 | | - and not logical_node.o.get_mention_name() |
325 | | - and isinstance(logical_node.o, SPOEntity) |
326 | | - ): |
327 | | - logical_node.o.entity_name = target_answer |
328 | | - context.kwargs[logical_node.o.alias_name] = logical_node.o |
329 | | - |
330 | | - context.variables_graph.add_answered_alias( |
331 | | - logical_node.s.alias_name.alias_name, retrieved_data.summary |
332 | | - ) |
333 | | - context.variables_graph.add_answered_alias( |
334 | | - logical_node.p.alias_name.alias_name, retrieved_data.summary |
| 362 | + context.variables_graph.add_answered_alias( |
| 363 | + logical_node.o.alias_name.alias_name, retrieved_data.summary |
| 364 | + ) |
| 365 | + |
| 366 | + task.update_result(retrieved_data) |
| 367 | + logger.debug( |
| 368 | + f"kag hybrid retrieval {task_query} cost={time.time() - start_time}" |
335 | 369 | ) |
336 | | - context.variables_graph.add_answered_alias( |
337 | | - logical_node.o.alias_name.alias_name, retrieved_data.summary |
| 370 | + return retrieved_data |
| 371 | + finally: |
| 372 | + self.report_content( |
| 373 | + reporter, |
| 374 | + "thinker", |
| 375 | + tag_id, |
| 376 | + "", |
| 377 | + "FINISH", |
| 378 | + step=task.name, |
| 379 | + overwrite=False, |
338 | 380 | ) |
339 | 381 |
|
340 | | - task.update_result(retrieved_data) |
341 | | - logger.debug( |
342 | | - f"kag hybrid retrieval {task_query} cost={time.time() - start_time}" |
343 | | - ) |
344 | | - return retrieved_data |
345 | | - |
346 | 382 | def schema(self) -> dict: |
347 | 383 | """Function schema definition for OpenAI Function Calling |
348 | 384 |
|
@@ -403,7 +439,7 @@ def do_data_report( |
403 | 439 | node_type=chunk.properties.get("__labels__"), |
404 | 440 | ) |
405 | 441 | entity_prop = dict(chunk.properties) if chunk.properties else {} |
406 | | - entity_prop["content"] = chunk.content |
| 442 | + entity_prop["content"] = f"{chunk.content[:10]}..." |
407 | 443 | entity_prop["score"] = chunk.score |
408 | 444 | entity.prop = Prop.from_dict(entity_prop, "Chunk", None) |
409 | 445 | chunk_graph.append(entity) |
|
0 commit comments