From 86b1f4dec52e03a7eeb08eda99ddb08c1a1e8420 Mon Sep 17 00:00:00 2001 From: lately818 <1460735292@qq.com> Date: Sun, 3 May 2026 00:40:13 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BARAG=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E5=8A=9F=E8=83=BD=E5=B9=B6=E6=B7=BB=E5=8A=A0=E5=85=83?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refactor(bm25): 重构BM25索引以支持文档元数据 feat(cross_encoder): 添加RerankResult类并更新预测接口 feat(rag): 实现RetrievedDocument和RAGResult数据结构 perf(vectorstore): 优化BM25和RRF融合算法 fix(chroma): 处理空元数据情况并添加日志 --- .../cross_encoder/cross_encoder_model.py | 24 +- .../huggingface_cross_encoder.py | 22 +- src/raglight/rag/rag.py | 380 +++++++++++++++--- src/raglight/vectorstore/bm25_index.py | 68 +++- src/raglight/vectorstore/chroma.py | 20 +- src/raglight/vectorstore/vector_store.py | 44 +- 6 files changed, 464 insertions(+), 94 deletions(-) diff --git a/src/raglight/cross_encoder/cross_encoder_model.py b/src/raglight/cross_encoder/cross_encoder_model.py index 2e422ce..bf14b66 100644 --- a/src/raglight/cross_encoder/cross_encoder_model.py +++ b/src/raglight/cross_encoder/cross_encoder_model.py @@ -1,6 +1,24 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List +from dataclasses import dataclass +from typing import Any, List, Optional + + +@dataclass +class RerankResult: + """ + Represents a single reranked document result with metadata and score. + + Attributes: + text (str): The document text content. + score (float): The relevance score from the cross encoder. + corpus_id (int): The original index in the input document list. + metadata (Optional[dict]): Original document metadata (source, etc.). + """ + text: str + score: float + corpus_id: int + metadata: Optional[dict] = None class CrossEncoderModel(ABC): @@ -48,7 +66,7 @@ def get_model(self) -> CrossEncoderModel: return self.model @abstractmethod - def predict(self, query: str, documents: List[str], top_k: int) -> List[str]: + def predict(self, query: str, documents: List[str], top_k: int) -> List[RerankResult]: """ Re-ranks the given documents against the query and returns the top_k most relevant. @@ -58,6 +76,6 @@ def predict(self, query: str, documents: List[str], top_k: int) -> List[str]: top_k (int): The number of top results to return. Returns: - List[str]: The top_k re-ranked document texts. + List[RerankResult]: The top_k re-ranked results with scores and corpus IDs. """ pass diff --git a/src/raglight/cross_encoder/huggingface_cross_encoder.py b/src/raglight/cross_encoder/huggingface_cross_encoder.py index d016347..4445aaa 100644 --- a/src/raglight/cross_encoder/huggingface_cross_encoder.py +++ b/src/raglight/cross_encoder/huggingface_cross_encoder.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import List from typing_extensions import override -from .cross_encoder_model import CrossEncoderModel +from .cross_encoder_model import CrossEncoderModel, RerankResult from sentence_transformers import CrossEncoder @@ -39,9 +39,9 @@ def load(self) -> HuggingfaceCrossEncoderModel: return CrossEncoder(self.model_name) @override - def predict(self, query: str, documents: List[str], top_k: int) -> List[str]: + def predict(self, query: str, documents: List[str], top_k: int) -> List[RerankResult]: """ - Predicts the similarity scores and returns the list of most relevant document texts. + Predicts the similarity scores and returns the list of most relevant document results. Args: query (str): The input query. @@ -49,12 +49,22 @@ def predict(self, query: str, documents: List[str], top_k: int) -> List[str]: top_k (int): The number of top results to return. Returns: - List[str]: The list of top_k re-ranked document texts. + List[RerankResult]: The list of top_k re-ranked results with scores and corpus IDs. """ + if not documents: + return [] + # rank returns a list of dicts: [{'corpus_id': int, 'score': float, 'text': str}, ...] results = self.model.rank( query=query, documents=documents, top_k=top_k, return_documents=True ) - # We extract and return only the text strings - return [res["text"] for res in results] + # Convert to RerankResult objects with all metadata + return [ + RerankResult( + text=res["text"], + score=float(res["score"]), + corpus_id=int(res["corpus_id"]) + ) + for res in results + ] diff --git a/src/raglight/rag/rag.py b/src/raglight/rag/rag.py index 063ae1a..6060489 100644 --- a/src/raglight/rag/rag.py +++ b/src/raglight/rag/rag.py @@ -3,6 +3,7 @@ import logging import os import uuid +from dataclasses import dataclass, field from typing import Any, Iterable, Optional from langchain_core.documents import Document @@ -18,6 +19,61 @@ logger = logging.getLogger(__name__) +@dataclass +class RetrievedDocument: + """ + Represents a retrieved document with all relevant metadata and scores. + + Attributes: + content (str): The document text content. + source (str): The source of the document (file path, URL, etc.). + metadata (Dict[str, Any]): Full metadata dictionary. + bm25_score (Optional[float]): BM25 relevance score if applicable. + rerank_score (Optional[float]): Cross-encoder rerank score if applicable. + rrf_score (Optional[float]): RRF fusion score if applicable. + """ + content: str + source: str = "Unknown" + metadata: Dict[str, Any] = field(default_factory=dict) + bm25_score: Optional[float] = None + rerank_score: Optional[float] = None + rrf_score: Optional[float] = None + + @classmethod + def from_langchain_doc(cls, doc: Document) -> "RetrievedDocument": + """Create a RetrievedDocument from a LangChain Document.""" + metadata = doc.metadata or {} + return cls( + content=doc.page_content, + source=metadata.get("source", "Unknown"), + metadata=dict(metadata), + bm25_score=metadata.get("bm25_score"), + rerank_score=metadata.get("rerank_score"), + rrf_score=metadata.get("rrf_score") or metadata.get("rrf_combined_score"), + ) + + +@dataclass +class RAGResult: + """ + Complete result from a RAG pipeline execution. + + Attributes: + answer (str): The generated answer from the LLM. + retrieved_docs (List[RetrievedDocument]): All documents retrieved during the process. + reformulated_question (Optional[str]): The reformulated question if query rewriting was used. + original_question (str): The original user question. + has_evidence (bool): Whether relevant evidence was found in the knowledge base. + error_message (Optional[str]): Error message if any stage failed and fallback was used. + """ + answer: str + retrieved_docs: List[RetrievedDocument] = field(default_factory=list) + reformulated_question: Optional[str] = None + original_question: str = "" + has_evidence: bool = True + error_message: Optional[str] = None + + class State(TypedDict): """ Represents the state of the RAG process. @@ -27,12 +83,16 @@ class State(TypedDict): context (List[Document]): A list of documents retrieved from the vector store as context. answer (str): The generated answer based on the input question and context. history (List[Dict[str, str]]): The history of the conversation. + reformulated_question (Optional[str]): The reformulated question if query rewriting was used. + error_message (Optional[str]): Error message if any stage failed. """ question: str answer: str context: List[Document] = [] history: List[Dict[str, str]] = [] + reformulated_question: Optional[str] = None + error_message: Optional[str] = None class RAG: @@ -92,36 +152,63 @@ def __init__( self._createGraph() ) # Here type is CompiledGraph but it's not exposed by https://github.com/langchain-ai/langgraph/blob/main/libs/langgraph/langgraph/graph/graph.py - def _reformulate(self, state: State) -> Dict[str, str]: + def _reformulate(self, state: State) -> Dict[str, Any]: """ Rewrites the question as a standalone question using the conversation history. - If there is no history, the original question is returned unchanged. + If there is no history or if reformulation fails, the original question is returned unchanged. Args: state (State): Current pipeline state with 'question' and 'history'. Returns: - Dict[str, str]: Updated state with the reformulated question. + Dict[str, Any]: Updated state with the reformulated question and reformulated_question field. """ + original_question = state["question"] + if not state["history"]: - return {"question": state["question"]} - - history_text = "\n".join( - f"{msg['role'].capitalize()}: {msg['content']}" for msg in state["history"] - ) - prompt = ( - f"Given the following conversation history and a follow-up question, " - f"rewrite the follow-up question as a standalone question that captures all necessary context.\n\n" - f"Conversation history:\n{history_text}\n\n" - f"Follow-up question: {state['question']}\n\n" - f"Standalone question (output ONLY the reformulated question, nothing else):" - ) - reformulated = self.llm.generate({"question": prompt, "history": []}) - logger.info(f"Reformulated question: {reformulated.strip()}") - return {"question": reformulated.strip()} + logger.info("No conversation history, skipping query reformulation") + return { + "question": original_question, + "reformulated_question": None + } - def _retrieve(self, state: State) -> Dict[str, List[Document]]: + try: + history_text = "\n".join( + f"{msg['role'].capitalize()}: {msg['content']}" for msg in state["history"] + ) + prompt = ( + f"Given the following conversation history and a follow-up question, " + f"rewrite the follow-up question as a standalone question that captures all necessary context.\n\n" + f"Conversation history:\n{history_text}\n\n" + f"Follow-up question: {original_question}\n\n" + f"Standalone question (output ONLY the reformulated question, nothing else):" + ) + reformulated = self.llm.generate({"question": prompt, "history": []}) + reformulated_question = reformulated.strip() + + if not reformulated_question or len(reformulated_question) < 3: + logger.warning(f"Reformulated question is empty or too short: '{reformulated_question}', using original") + return { + "question": original_question, + "reformulated_question": None + } + + logger.info(f"Query reformulated: '{original_question}' -> '{reformulated_question}'") + return { + "question": reformulated_question, + "reformulated_question": reformulated_question + } + + except Exception as e: + logger.error(f"Query reformulation failed, using original question. Error: {e}") + return { + "question": original_question, + "reformulated_question": None, + "error_message": f"Reformulation failed: {str(e)}" + } + + def _retrieve(self, state: State) -> Dict[str, Any]: """ Retrieves relevant documents based on the input question. @@ -129,25 +216,94 @@ def _retrieve(self, state: State) -> Dict[str, List[Document]]: state (Dict[str, str]): A dictionary containing the input question under the key 'question'. Returns: - Dict[str, List[Document]]: A dictionary containing the retrieved documents under the key 'context'. + Dict[str, Any]: A dictionary containing the retrieved documents under the key 'context'. """ - retrieved_docs = self.vector_store.similarity_search( - state["question"], k=self.k - ) - return {"context": retrieved_docs, "question": state["question"]} + question = state["question"] + try: + retrieved_docs = self.vector_store.similarity_search( + question, k=self.k + ) + + if not retrieved_docs: + logger.warning(f"No documents retrieved for query: '{question}'") + else: + logger.info(f"Retrieved {len(retrieved_docs)} documents for query: '{question}'") + + return { + "context": retrieved_docs, + "question": question + } + + except Exception as e: + logger.error(f"Retrieval failed for query: '{question}'. Error: {e}") + return { + "context": [], + "question": question, + "error_message": f"Retrieval failed: {str(e)}" + } def _build_prompt(self, state: Dict) -> str: - docs_content = "\n\n".join(doc.page_content for doc in state["context"]) + """ + Builds a prompt that includes context with source citations and requires evidence-based answers. + """ + context_docs = state.get("context", []) + + if not context_docs: + return f""" +You are a helpful assistant. The user asked: {state["question"]} + +IMPORTANT: No relevant documents were found in the knowledge base to answer this question. + +Please respond exactly as follows: +"无法根据知识库中的内容回答此问题。知识库中没有找到与该问题相关的信息。" +""" + + context_sections = [] + for idx, doc in enumerate(context_docs, 1): + metadata = doc.metadata if doc.metadata else {} + source = metadata.get("source", "Unknown") + score_info = [] + + if "rerank_score" in metadata: + score_info.append(f"rerank_score={metadata['rerank_score']:.4f}") + if "bm25_score" in metadata: + score_info.append(f"bm25_score={metadata['bm25_score']:.4f}") + if "rrf_score" in metadata: + score_info.append(f"rrf_score={metadata['rrf_score']:.4f}") + if "rrf_combined_score" in metadata: + score_info.append(f"rrf_combined={metadata['rrf_combined_score']:.4f}") + + score_str = ", ".join(score_info) if score_info else "N/A" + + context_sections.append(f"""--- +[Document {idx}] +Source: {source} +Relevance Scores: {score_str} +Content: +{doc.page_content} +---""") + + context_str = "\n".join(context_sections) + return f""" - Here is the retrieved context (excerpts from the document): - {docs_content} +You are an evidence-based assistant. Your answers MUST be strictly based on the provided context documents. - Here is the question: - {state["question"]} +## Retrieved Context Documents: +{context_str} +## User Question: +{state["question"]} - FINAL ANSWER (based only on the context): - """ +## Instructions: +1. **ONLY use information explicitly stated in the retrieved context documents** +2. **Cite your sources** using [n] notation where n is the document number (e.g., "According to [1], ...") +3. If the context does not contain enough information to answer the question, respond EXACTLY with: + "无法根据知识库中的内容回答此问题。知识库中没有找到与该问题相关的信息。" +4. Do NOT guess, fabricate, or use outside knowledge +5. If multiple documents contain relevant information, cite all relevant sources + +## Final Answer (based only on the context): +""" def _generate_graph(self, state: Dict[str, List[Document]]) -> Dict[str, str]: """ @@ -168,6 +324,7 @@ def _generate_graph(self, state: Dict[str, List[Document]]) -> Dict[str, str]: def _rerank(self, state: Dict[str, List[Document]]) -> Dict[str, List[Document]]: """ Reranks the retrieved documents based on the cross-encoder model. + Preserves original metadata and adds rerank scores. Args: state (Dict[str, List[Document]]): A dictionary containing the list of retrieved documents under the key 'context'. @@ -178,16 +335,33 @@ def _rerank(self, state: Dict[str, List[Document]]) -> Dict[str, List[Document]] try: question = state["question"] docs = state["context"] + + if not docs: + logger.warning("No documents to rerank, returning empty context") + return {"context": [], "question": state["question"]} + doc_texts = [doc.page_content for doc in docs] + top_k = max(1, int(self.k / 4)) - ranked_texts = self.cross_encoder.predict( - question, doc_texts, int(self.k / 4) + ranked_results = self.cross_encoder.predict( + question, doc_texts, top_k ) - ranked_docs = [Document(page_content=text) for text in ranked_texts] + ranked_docs = [] + for result in ranked_results: + original_doc = docs[result.corpus_id] + new_metadata = dict(original_doc.metadata) if original_doc.metadata else {} + new_metadata["rerank_score"] = result.score + new_metadata["original_index"] = result.corpus_id + ranked_docs.append(Document( + page_content=result.text, + metadata=new_metadata + )) + + logger.info(f"Rerank: {len(docs)} -> {len(ranked_docs)} documents preserved with metadata") except Exception as e: - logger.warning(f"Reranking failed: {e}") + logger.error(f"Reranking failed, falling back to original context. Error: {e}") ranked_docs = state["context"] return {"context": ranked_docs, "question": state["question"]} @@ -250,25 +424,79 @@ def generate(self, question: str) -> str: Returns: str: The generated answer from the pipeline. """ + result = self.generate_with_result(question) + return result.answer + + def generate_with_result(self, question: str) -> RAGResult: + """ + Executes the RAG pipeline for a given question and returns a complete RAGResult. + + Args: + question (str): The input question. + + Returns: + RAGResult: Complete result containing answer, retrieved documents, and metadata. + """ self.state["question"] = question + original_question = question if self.max_history is not None: self.state["history"] = self.state["history"][-self.max_history :] - if self.langfuse_config: - callback = self._build_langfuse_callback() - response = self.graph.invoke(self.state, config={"callbacks": [callback]}) - else: - response = self.graph.invoke(self.state) - - answer = response["answer"] - self.state["history"].extend( - [ - {"role": "user", "content": question}, - {"role": "assistant", "content": answer}, + try: + if self.langfuse_config: + callback = self._build_langfuse_callback() + response = self.graph.invoke(self.state, config={"callbacks": [callback]}) + else: + response = self.graph.invoke(self.state) + + answer = response["answer"] + context_docs = response.get("context", []) + reformulated_question = response.get("reformulated_question") + error_message = response.get("error_message") + + retrieved_docs = [ + RetrievedDocument.from_langchain_doc(doc) for doc in context_docs ] - ) - return answer + + has_evidence = len(retrieved_docs) > 0 + + self.state["history"].extend( + [ + {"role": "user", "content": original_question}, + {"role": "assistant", "content": answer}, + ] + ) + + return RAGResult( + answer=answer, + retrieved_docs=retrieved_docs, + reformulated_question=reformulated_question, + original_question=original_question, + has_evidence=has_evidence, + error_message=error_message, + ) + + except Exception as e: + logger.error(f"RAG pipeline failed for question: '{question}'. Error: {e}") + + fallback_answer = "无法根据知识库中的内容回答此问题。处理过程中发生错误。" + + self.state["history"].extend( + [ + {"role": "user", "content": original_question}, + {"role": "assistant", "content": fallback_answer}, + ] + ) + + return RAGResult( + answer=fallback_answer, + retrieved_docs=[], + reformulated_question=None, + original_question=original_question, + has_evidence=False, + error_message=f"Pipeline failed: {str(e)}", + ) def generate_streaming(self, question: str) -> Iterable[str]: """ @@ -292,10 +520,21 @@ def generate_streaming(self, question: str) -> Iterable[str]: "history": list(self.state["history"]), } - if self.reformulation: - state.update(self._reformulate(state)) + reformulated_question = None + error_messages = [] - state.update(self._retrieve(state)) + if self.reformulation: + reform_result = self._reformulate(state) + state.update(reform_result) + if reform_result.get("reformulated_question"): + reformulated_question = reform_result["reformulated_question"] + if reform_result.get("error_message"): + error_messages.append(reform_result["error_message"]) + + retrieve_result = self._retrieve(state) + state.update(retrieve_result) + if retrieve_result.get("error_message"): + error_messages.append(retrieve_result["error_message"]) if self.cross_encoder: state.update(self._rerank(state)) @@ -304,16 +543,29 @@ def generate_streaming(self, question: str) -> Iterable[str]: callbacks = [self._build_langfuse_callback()] if self.langfuse_config else None - full_answer = "" - for chunk in self.llm.generate_streaming( - {"question": prompt, "history": state["history"]}, callbacks=callbacks - ): - full_answer += chunk - yield chunk - - self.state["history"].extend( - [ - {"role": "user", "content": question}, - {"role": "assistant", "content": full_answer}, - ] - ) + try: + full_answer = "" + for chunk in self.llm.generate_streaming( + {"question": prompt, "history": state["history"]}, callbacks=callbacks + ): + full_answer += chunk + yield chunk + + self.state["history"].extend( + [ + {"role": "user", "content": question}, + {"role": "assistant", "content": full_answer}, + ] + ) + + except Exception as e: + logger.error(f"Streaming generation failed. Error: {e}") + fallback_answer = "无法根据知识库中的内容回答此问题。处理过程中发生错误。" + yield fallback_answer + + self.state["history"].extend( + [ + {"role": "user", "content": question}, + {"role": "assistant", "content": fallback_answer}, + ] + ) diff --git a/src/raglight/vectorstore/bm25_index.py b/src/raglight/vectorstore/bm25_index.py index d85777d..47564b3 100644 --- a/src/raglight/vectorstore/bm25_index.py +++ b/src/raglight/vectorstore/bm25_index.py @@ -1,43 +1,87 @@ from __future__ import annotations import json import re +from dataclasses import dataclass, asdict from pathlib import Path -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Dict, Any from rank_bm25 import BM25Okapi +@dataclass +class BM25Document: + """Represents a document in the BM25 index with text and metadata.""" + text: str + metadata: Dict[str, Any] + + class BM25Index: - """Lightweight BM25 index over a list of text documents.""" + """Lightweight BM25 index over a list of text documents with metadata support.""" def __init__(self) -> None: - self.corpus: List[str] = [] + self.documents: List[BM25Document] = [] self._bm25: Optional[BM25Okapi] = None def _tokenize(self, text: str) -> List[str]: return re.findall(r"\w+", text.lower()) def _rebuild(self) -> None: - if self.corpus: - self._bm25 = BM25Okapi([self._tokenize(t) for t in self.corpus]) + if self.documents: + self._bm25 = BM25Okapi([self._tokenize(doc.text) for doc in self.documents]) else: self._bm25 = None - def add_documents(self, texts: List[str]) -> None: - self.corpus.extend(texts) + def add_documents(self, texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None) -> None: + """ + Add documents to the BM25 index. + + Args: + texts: List of document text contents. + metadatas: Optional list of metadata dictionaries corresponding to each text. + """ + if metadatas is None: + metadatas = [{} for _ in texts] + + # Ensure metadatas length matches texts + if len(metadatas) != len(texts): + metadatas = metadatas + [{}] * (len(texts) - len(metadatas)) + + for text, metadata in zip(texts, metadatas): + self.documents.append(BM25Document(text=text, metadata=metadata or {})) + self._rebuild() - def search(self, query: str, k: int) -> List[Tuple[int, float]]: - if not self._bm25 or not self.corpus: + def search(self, query: str, k: int) -> List[Tuple[int, float, str, Dict[str, Any]]]: + """ + Search the BM25 index and return results with metadata. + + Args: + query: The search query. + k: Number of top results to return. + + Returns: + List of tuples: (index, score, text, metadata) + """ + if not self._bm25 or not self.documents: return [] + tokens = self._tokenize(query) scores = self._bm25.get_scores(tokens) indexed = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) - return [(idx, score) for idx, score in indexed[:k]] + + results = [] + for idx, score in indexed[:k]: + if idx < len(self.documents): + doc = self.documents[idx] + results.append((idx, score, doc.text, doc.metadata)) + + return results def save(self, path: Path) -> None: - path.write_text(json.dumps(self.corpus, ensure_ascii=False), encoding="utf-8") + data = [asdict(doc) for doc in self.documents] + path.write_text(json.dumps(data, ensure_ascii=False), encoding="utf-8") def load(self, path: Path) -> None: - self.corpus = json.loads(path.read_text(encoding="utf-8")) + data = json.loads(path.read_text(encoding="utf-8")) + self.documents = [BM25Document(**item) for item in data] self._rebuild() diff --git a/src/raglight/vectorstore/chroma.py b/src/raglight/vectorstore/chroma.py index c36849a..e6aa0b4 100644 --- a/src/raglight/vectorstore/chroma.py +++ b/src/raglight/vectorstore/chroma.py @@ -12,6 +12,8 @@ from .vector_store import VectorStore from ..embeddings.embeddings_model import EmbeddingsModel +logger = logging.getLogger(__name__) + class ChromaEmbeddingAdapter(EmbeddingFunction): """ @@ -84,8 +86,24 @@ def __init__( def _rebuild_bm25_from_chroma(self) -> None: result = self.collection.get() texts = result.get("documents") or [] + metadatas = result.get("metadatas") or [] + if texts: - self._bm25.add_documents(texts) + if not metadatas or len(metadatas) != len(texts): + metadatas = [{} for _ in texts] + + # Convert None metadata to empty dict + processed_metadatas = [] + for meta in metadatas: + if meta is None: + processed_metadatas.append({}) + elif isinstance(meta, dict): + processed_metadatas.append(dict(meta)) + else: + processed_metadatas.append({"raw_meta": str(meta)}) + + self._bm25.add_documents(texts, processed_metadatas) + logger.info(f"Rebuilt BM25 index from Chroma: {len(texts)} documents with metadata") @override def add_documents(self, documents: List[Document]) -> None: diff --git a/src/raglight/vectorstore/vector_store.py b/src/raglight/vectorstore/vector_store.py index e3f04b1..768eabc 100644 --- a/src/raglight/vectorstore/vector_store.py +++ b/src/raglight/vectorstore/vector_store.py @@ -13,6 +13,8 @@ from ..config.settings import Settings from .bm25_index import BM25Index +logger = logging.getLogger(__name__) + class VectorStore(ABC): """ @@ -51,7 +53,8 @@ def _bm25_path(self) -> Optional[Path]: def _update_bm25(self, documents: List[Document]) -> None: texts = [doc.page_content for doc in documents] - self._bm25.add_documents(texts) + metadatas = [dict(doc.metadata) if doc.metadata else {} for doc in documents] + self._bm25.add_documents(texts, metadatas) bm25_path = self._bm25_path() if bm25_path: self._bm25.save(bm25_path) @@ -59,9 +62,12 @@ def _update_bm25(self, documents: List[Document]) -> None: def _bm25_search(self, question: str, k: int) -> List[Document]: results = self._bm25.search(question, k) docs = [] - for idx, _score in results: - if idx < len(self._bm25.corpus): - docs.append(Document(page_content=self._bm25.corpus[idx])) + for idx, score, text, metadata in results: + new_metadata = dict(metadata) if metadata else {} + new_metadata["bm25_score"] = score + new_metadata["bm25_index"] = idx + docs.append(Document(page_content=text, metadata=new_metadata)) + logger.info(f"BM25 search: found {len(docs)} documents with metadata") return docs def _rrf( @@ -69,13 +75,35 @@ def _rrf( ) -> List[Document]: scores: Dict[str, float] = {} doc_map: Dict[str, Document] = {} + for ranked in ranked_lists: for rank, doc in enumerate(ranked): - key = doc.page_content[:100] - scores[key] = scores.get(key, 0) + 1 / (k_rrf + rank + 1) - doc_map[key] = doc + key = doc.page_content[:200] + rrf_score = 1 / (k_rrf + rank + 1) + + if key in scores: + scores[key] += rrf_score + existing_doc = doc_map[key] + merged_metadata = dict(existing_doc.metadata) + merged_metadata.update(doc.metadata) + merged_metadata["rrf_combined_score"] = scores[key] + doc_map[key] = Document( + page_content=existing_doc.page_content, + metadata=merged_metadata + ) + else: + scores[key] = rrf_score + new_metadata = dict(doc.metadata) if doc.metadata else {} + new_metadata["rrf_score"] = rrf_score + doc_map[key] = Document( + page_content=doc.page_content, + metadata=new_metadata + ) + sorted_keys = sorted(scores, key=scores.__getitem__, reverse=True) - return [doc_map[k] for k in sorted_keys] + result = [doc_map[k] for k in sorted_keys] + logger.info(f"RRF fusion: merged {len(ranked_lists)} ranked lists into {len(result)} unique documents") + return result def _hybrid_search( self, question: str, k: int, filter: Optional[Dict[str, Any]] From e09ffcb39c4caf234c0d6825791c2f6ced8575bd Mon Sep 17 00:00:00 2001 From: lately818 <1460735292@qq.com> Date: Sun, 3 May 2026 01:18:31 +0800 Subject: [PATCH 2/2] =?UTF-8?q?feat(=E5=90=91=E9=87=8F=E5=AD=98=E5=82=A8):?= =?UTF-8?q?=20=E4=B8=BA=E6=96=87=E6=A1=A3=E6=A3=80=E7=B4=A2=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E9=98=B6=E6=AE=B5=E8=BF=BD=E8=B8=AA=E5=85=83=E6=95=B0?= =?UTF-8?q?=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在Chroma、Qdrant和BM25检索过程中添加retrieval_stage元数据标记 新增retrieval_stages列表记录文档经过的检索阶段 保持BM25向后兼容性同时新增search_with_metadata方法 改进元数据处理逻辑,确保字典深拷贝 --- src/raglight/rag/rag.py | 8 +++++ src/raglight/vectorstore/bm25_index.py | 37 +++++++++++++++++++++--- src/raglight/vectorstore/chroma.py | 3 +- src/raglight/vectorstore/qdrant.py | 28 +++++++++++++----- src/raglight/vectorstore/vector_store.py | 16 +++++++++- 5 files changed, 78 insertions(+), 14 deletions(-) diff --git a/src/raglight/rag/rag.py b/src/raglight/rag/rag.py index 6060489..47869b3 100644 --- a/src/raglight/rag/rag.py +++ b/src/raglight/rag/rag.py @@ -353,6 +353,14 @@ def _rerank(self, state: Dict[str, List[Document]]) -> Dict[str, List[Document]] new_metadata = dict(original_doc.metadata) if original_doc.metadata else {} new_metadata["rerank_score"] = result.score new_metadata["original_index"] = result.corpus_id + + original_stage = new_metadata.get("retrieval_stage", "unknown") + retrieval_stages = new_metadata.get("retrieval_stages", []) + if original_stage not in retrieval_stages: + retrieval_stages.append(original_stage) + new_metadata["retrieval_stages"] = retrieval_stages + new_metadata["retrieval_stage"] = "reranked" + ranked_docs.append(Document( page_content=result.text, metadata=new_metadata diff --git a/src/raglight/vectorstore/bm25_index.py b/src/raglight/vectorstore/bm25_index.py index 47564b3..1550c60 100644 --- a/src/raglight/vectorstore/bm25_index.py +++ b/src/raglight/vectorstore/bm25_index.py @@ -16,12 +16,22 @@ class BM25Document: class BM25Index: - """Lightweight BM25 index over a list of text documents with metadata support.""" + """ + Lightweight BM25 index over a list of text documents with metadata support. + + Backward compatible: search() returns (index, score) for existing tests. + Use search_with_metadata() for full results with text and metadata. + """ def __init__(self) -> None: self.documents: List[BM25Document] = [] self._bm25: Optional[BM25Okapi] = None + @property + def corpus(self) -> List[str]: + """Backward compatible: return list of text contents.""" + return [doc.text for doc in self.documents] + def _tokenize(self, text: str) -> List[str]: return re.findall(r"\w+", text.lower()) @@ -42,7 +52,6 @@ def add_documents(self, texts: List[str], metadatas: Optional[List[Dict[str, Any if metadatas is None: metadatas = [{} for _ in texts] - # Ensure metadatas length matches texts if len(metadatas) != len(texts): metadatas = metadatas + [{}] * (len(texts) - len(metadatas)) @@ -51,9 +60,29 @@ def add_documents(self, texts: List[str], metadatas: Optional[List[Dict[str, Any self._rebuild() - def search(self, query: str, k: int) -> List[Tuple[int, float, str, Dict[str, Any]]]: + def search(self, query: str, k: int) -> List[Tuple[int, float]]: + """ + Search the BM25 index (backward compatible). + + Args: + query: The search query. + k: Number of top results to return. + + Returns: + List of tuples: (index, score) - backward compatible format + """ + if not self._bm25 or not self.documents: + return [] + + tokens = self._tokenize(query) + scores = self._bm25.get_scores(tokens) + indexed = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) + + return [(idx, score) for idx, score in indexed[:k]] + + def search_with_metadata(self, query: str, k: int) -> List[Tuple[int, float, str, Dict[str, Any]]]: """ - Search the BM25 index and return results with metadata. + Search the BM25 index and return results with full metadata. Args: query: The search query. diff --git a/src/raglight/vectorstore/chroma.py b/src/raglight/vectorstore/chroma.py index e6aa0b4..8a4786e 100644 --- a/src/raglight/vectorstore/chroma.py +++ b/src/raglight/vectorstore/chroma.py @@ -191,7 +191,8 @@ def _query_collection( else [{}] * len(docs_list) ) for text, meta in zip(docs_list, metas_list): - safe_meta = meta if isinstance(meta, dict) else {} + safe_meta = dict(meta) if isinstance(meta, dict) else {} + safe_meta["retrieval_stage"] = "semantic" found_docs.append(Document(page_content=text, metadata=safe_meta)) return found_docs diff --git a/src/raglight/vectorstore/qdrant.py b/src/raglight/vectorstore/qdrant.py index c278a2b..f774d12 100644 --- a/src/raglight/vectorstore/qdrant.py +++ b/src/raglight/vectorstore/qdrant.py @@ -10,6 +10,8 @@ from .vector_store import VectorStore from ..embeddings.embeddings_model import EmbeddingsModel +logger = logging.getLogger(__name__) + class QdrantVS(VectorStore): """ @@ -73,15 +75,24 @@ def _rebuild_bm25_from_qdrant(self) -> None: records, _ = self.client.scroll( collection_name=self.collection_name, limit=10_000, with_payload=True ) - texts = [ - r.payload.get("page_content", "") - for r in records - if r.payload and r.payload.get("page_content") - ] + + texts = [] + metadatas = [] + + for r in records: + if r.payload and r.payload.get("page_content"): + page_content = r.payload.get("page_content", "") + texts.append(page_content) + + metadata = dict(r.payload) + metadata.pop("page_content", None) + metadatas.append(metadata) + if texts: - self._bm25.add_documents(texts) + self._bm25.add_documents(texts, metadatas) + logger.info(f"Rebuilt BM25 index from Qdrant: {len(texts)} documents with metadata") except Exception as e: - logging.warning(f"Could not rebuild BM25 from Qdrant: {e}") + logger.warning(f"Could not rebuild BM25 from Qdrant: {e}") def _ensure_collection(self, name: str) -> None: from qdrant_client.models import Distance, VectorParams @@ -146,8 +157,9 @@ def _semantic_search( docs = [] for hit in results: - payload = hit.payload or {} + payload = dict(hit.payload) if hit.payload else {} page_content = payload.pop("page_content", "") + payload["retrieval_stage"] = "semantic" docs.append(Document(page_content=page_content, metadata=payload)) return docs diff --git a/src/raglight/vectorstore/vector_store.py b/src/raglight/vectorstore/vector_store.py index 768eabc..02d2af7 100644 --- a/src/raglight/vectorstore/vector_store.py +++ b/src/raglight/vectorstore/vector_store.py @@ -60,12 +60,13 @@ def _update_bm25(self, documents: List[Document]) -> None: self._bm25.save(bm25_path) def _bm25_search(self, question: str, k: int) -> List[Document]: - results = self._bm25.search(question, k) + results = self._bm25.search_with_metadata(question, k) docs = [] for idx, score, text, metadata in results: new_metadata = dict(metadata) if metadata else {} new_metadata["bm25_score"] = score new_metadata["bm25_index"] = idx + new_metadata["retrieval_stage"] = "bm25" docs.append(Document(page_content=text, metadata=new_metadata)) logger.info(f"BM25 search: found {len(docs)} documents with metadata") return docs @@ -85,8 +86,16 @@ def _rrf( scores[key] += rrf_score existing_doc = doc_map[key] merged_metadata = dict(existing_doc.metadata) + + retrieval_stages = merged_metadata.get("retrieval_stages", []) + current_stage = doc.metadata.get("retrieval_stage", "unknown") + if current_stage not in retrieval_stages: + retrieval_stages.append(current_stage) + merged_metadata["retrieval_stages"] = retrieval_stages + merged_metadata.update(doc.metadata) merged_metadata["rrf_combined_score"] = scores[key] + merged_metadata["retrieval_stage"] = "hybrid" doc_map[key] = Document( page_content=existing_doc.page_content, metadata=merged_metadata @@ -95,6 +104,11 @@ def _rrf( scores[key] = rrf_score new_metadata = dict(doc.metadata) if doc.metadata else {} new_metadata["rrf_score"] = rrf_score + + retrieval_stage = new_metadata.get("retrieval_stage", "unknown") + new_metadata["retrieval_stages"] = [retrieval_stage] + new_metadata["retrieval_stage"] = "hybrid" + doc_map[key] = Document( page_content=doc.page_content, metadata=new_metadata