feat: Vllm reranker model bge reranker v2 m3 #3909
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected; please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes-sigs/prow repository.
[APPROVALNOTIFIER] This PR is NOT APPROVED. This pull-request has been approved by: The full list of commands accepted by this bot can be found here. Needs approval from an approver in each of these files: Approvers can indicate their approval by writing
The provided code looks generally good, but there are a few areas where improvements can be made:
- Unused Import: The commented-out `@@ -0,0 +1,50 @@` line indicates an unused placeholder that should be removed.
- Docstring Improvements: Consider providing more detailed docstrings for each method to improve readability and understanding.
- Variable Names: Variable names like `model_info` could benefit from being more descriptive.
- Code Formatting: Ensure consistent formatting throughout the file, such as using proper indentation and spacing rules.
Here's an improved version of the code with these suggestions addressed:
@@ -8,6 +8,7 @@
from typing import Dict, TypedDict
from langchain.chains.summarize.retrieval_qa_reranking import RetrievalQARerankChain, Retriever, LLM
from langchain.prompts import PromptTemplate
+from langchain.vectorstores import FAISSStore  # Assuming FAISSStore is part of your VectorStores implementation
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm


class VllmRerankerCredential(BaseForm, BaseModelCredential):
    api_url: str = forms.TextInputField('API URL', required=True)
    api_key: str = forms.PasswordInputField('API Key', required=True)
    def is_valid(self, model_type: str, model_name, model_credential: dict, model_params, provider=None,
                 raise_exception=True):
        """
        Validates the VLLM Reranker credentials based on the model type and additional checks.

        :param model_type: Type of the model
        :param model_name: Name of the model
        :param model_credential: Dictionary containing necessary credentials
        :param model_params: Additional parameters for the model
        :param provider: Provider instance used to look up the supported model types
        :param raise_exception: Whether to raise an exception on failure
        :return: A (success, message) tuple
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt['value'] == model_type for mt in model_type_list):
            self.log.error(f"Model type '{model_type}' is not supported")
            return False, "Unsupported model type"

        missing_fields = [k for k in ('api_url', 'api_key') if k not in model_credential]
        if missing_fields:
            error_message = ", ".join(missing_fields).capitalize() + " field(s) is/are required."
            self.log.error(error_message)
            return False, error_message

        try:
            # Create a dummy document and attempt to run the rerank chain against it
            faiss_store = FAISSStore.from_texts(["Hello"], None)  # Replace with an actual embedding/vectorization step
            rerank_chain = RetrievalQARerankChain(
                retriever=faiss_store.as_retriever(),
                llm=LLM(prompt_template=PromptTemplate(input_variables=["question", "context"], template="...")),
                max_answer_length=model_params.get("max_answer_length", 50),
                temperature=model_params.get("temperature", 0.2)
            )
            rerank_chain({"input": {"question": _("Hello"), "context": _("This is a test.")}})
        except Exception as e:
            self.log.error(f"Verification failed: {e}")
            if isinstance(e, AppApiException):
                raise
            if raise_exception:
                err_code = ValidCode.valid_error.value
                error_reason = f"Verification failed, please check whether the parameters are correct: {str(e)}"
                self.log.error(error_reason)
                raise AppApiException(err_code, error_reason)
            return False, "Verification failed"
        return True, "Validation successful"
    def encryption_dict(self, model_info: Dict[str, object]) -> Dict[str, object]:
        """
        Encrypts sensitive information within the model info dictionary.

        :param model_info: A dictionary containing model details, including API keys
        :return: Model information with the API key stored in encrypted form
        """
        return {
            **model_info,
            "enc_api_key": super().encrypt(model_info.get("api_key", ""))
        }

Changes Made:
- Removed the unused first comment.
- Provided clearer docstrings for better understanding.
- Used snake_case variable names instead of camelCase.
- Added comments explaining each section of the code.
- Improved error logging and exception handling by adding context logs.
ds = [d.page_content for d in documents]
result = self.client.rerank(model=self.model, query=query, documents=ds)
return [Document(page_content=d.document.get('text'), metadata={'relevance_score': d.relevance_score}) for d in
        result.results]
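For context, results of this shape can be re-ordered client-side by their relevance score. A minimal sketch, using plain dicts as illustrative stand-ins for the client's result objects:

```python
# Illustrative stand-ins for rerank results: each entry pairs a document
# text with the relevance score the reranker assigned to it.
results = [
    {"text": "doc A", "relevance_score": 0.12},
    {"text": "doc B", "relevance_score": 0.87},
    {"text": "doc C", "relevance_score": 0.43},
]

# Sort descending so the most relevant document comes first
ranked = sorted(results, key=lambda r: r["relevance_score"], reverse=True)
print([r["text"] for r in ranked])  # → ['doc B', 'doc C', 'doc A']
```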
The provided code is a LangChain document compressor that integrates with Cohere's rerank API client. While it looks functional, there are several areas where improvements can be made:
Improvements:
- Type Annotations: The `BaseDocumentCompressor` import is not being used in this class. If you intended to use it, ensure all necessary methods from this interface are implemented.
- Error Handling: Adding error handling for exceptions raised during HTTP requests and model inference could improve robustness.
- Logging: Consider adding logging statements to track the flow of data through the `compress_documents` method.
- Performance Optimization: For very large documents or queries, consider optimizing how documents are processed before sending them to the server. This might involve chunking long documents or performing partial re-rankings.
- Callback Support: Ensure that the `callbacks` parameter correctly passes through and works within the `cohere.ClientV2` interactions.
- Testing: Add unit tests to cover various scenarios, such as null input/output, different document types, etc.
- Configuration Management: Provide better configuration mechanisms, especially regarding authentication keys and URLs, possibly allowing users to load configurations directly from YAML files or environment variables.
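As a sketch of the environment-variable option in the last point (the variable names here are illustrative, not something this PR defines):

```python
import os

# Hypothetical helper: prefer explicitly passed credentials, fall back to
# environment variables. Names are illustrative, not part of the PR.
def resolve_credentials(api_key=None, api_url=None):
    return {
        "api_key": api_key or os.environ.get("VLLM_RERANKER_API_KEY", ""),
        "api_url": api_url or os.environ.get("VLLM_RERANKER_API_URL", "http://localhost:8000/v1"),
    }

creds = resolve_credentials(api_key="sk-test")
print(creds["api_key"])  # → sk-test
```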
Here’s an updated version with some suggestions:
from typing import Sequence, Optional, Dict, Any
import logging

import cohere
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_cohere_client(api_key: str, api_url: str) -> cohere.ClientV2:
    # cohere.ClientV2 takes the key as `api_key` and the server URL as `base_url`
    return cohere.ClientV2(api_key=api_key, base_url=api_url)


MAX_DOCUMENTS = 100  # Limit number of documents sent per request to avoid rate limits
class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
    # MaxKBBaseModel is the project-local base class, imported elsewhere in the package
    _client = None
    api_key: str
    api_url: str
    model: str
    params: dict

    def __init__(self, api_key: str, api_url: str, model: str, params: dict, callbacks: Callbacks = None):
        super().__init__()
        self.api_key = api_key
        self.model = model
        self.params = params
        self.api_url = api_url
        try:
            # Log the endpoint, never the key itself
            logger.info(f"Initializing client for {self.api_url}")
            self._client = create_cohere_client(api_key=self.api_key, api_url=self.api_url)
        except (TypeError, ValueError) as e:
            raise RuntimeError("Invalid credential format") from e

    @staticmethod
    def is_cache_model() -> bool:
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return VllmBgeReranker(
            model=model_name,
            api_key=model_credential['api_key'],
            api_url=model_credential['api_url'],
            params=model_kwargs
        )
    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks: Optional[Callbacks] = None,
                           batch_size: int = MAX_DOCUMENTS) -> Sequence[Document]:
        """
        Compresses the given documents based on relevance to the query using the reranker.

        :param documents: A sequence of documents to compress.
        :param query: The search query string.
        :param callbacks: An optional list of callback functions.
        :param batch_size: Maximum number of documents sent per request.
        :return: List of compressed documents with relevance scores.
        """
        if documents is None or len(documents) == 0:
            return []

        # Chunk documents into batches to avoid hitting rate limits
        chunks = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
        responses = []
        for chunk in chunks:
            try:
                logger.debug(f"Sending {len(chunk)} documents for rerank")
                documents_texts = [doc.page_content for doc in chunk]
                result = self._client.rerank(
                    model=self.model,
                    query=query,
                    documents=documents_texts,
                    # Pass any extra rerank parameters straight through; the
                    # client has no `custom_parameters` keyword argument
                    **self.params
                )
                responses.extend(result.results)
            except Exception as e:
                logger.error(f"Failed to process batch of {len(chunk)} documents: {str(e)}")
                continue

        compressed_docs = [
            Document(page_content=result.document.get('text'), metadata={'relevance_score': result.relevance_score})
            for result in responses
        ]
        return compressed_docs

This updated function handles batching of inputs, adds basic error handling, and includes logging. It also allows specification of a batch_size, which helps reduce the impact of rate limits when dealing with large sets of documents.
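The batching arithmetic used above can be checked in isolation, independent of any API client:

```python
MAX_DOCUMENTS = 100

# 250 documents should split into batches of 100, 100, and 50
documents = [f"doc-{i}" for i in range(250)]
chunks = [documents[i:i + MAX_DOCUMENTS] for i in range(0, len(documents), MAX_DOCUMENTS)]
print([len(c) for c in chunks])  # → [100, 100, 50]
```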
    .append_default_model_info(reranker_model_info_list[0])
    .build()
)
There are no significant irregularities or potential issues in this code. The structure of the module appears to be complete, and there is adequate handling of different types of models (LLMs, images, embeddings, STTs) with their respective credentials. Additionally, it looks like a good practice to include both default and primary model information in the model_info_manage object. However, one small suggestion is that if you plan to add more features or expand the codebase, consider using classes instead of individual functions for better organization and encapsulation.
build: try to enable --use-feature=fast-deps.
feat: Vllm reranker model bge reranker v2 m3