
feat: Vllm reranker model bge reranker v2 m3#3909

Merged
zhanweizhang7 merged 2 commits into v2 from pr@v2@feat_vllm_reranker
Aug 21, 2025
Conversation

@shaohuzhang1
Contributor

build: try to enable --use-feature=fast-deps.
feat: Vllm reranker model bge reranker v2 m3

@f2c-ci-robot

f2c-ci-robot bot commented Aug 21, 2025

Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it.
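Following the Kubernetes/prow release-note convention, the label is removed by adding a fenced release-note block to the PR description. A plausible block for this PR (wording is illustrative) would be:

````markdown
```release-note
Add vLLM reranker support for bge-reranker-v2-m3
```
````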

Details

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes-sigs/prow repository.

@f2c-ci-robot

f2c-ci-robot bot commented Aug 21, 2025

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:

The full list of commands accepted by this bot can be found here.

Details

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

Contributor Author


The provided code looks generally good, but there are a few areas where improvements can be made:

  1. Diff Artifact: The `@@ -0,0 +1,50 @@` line is a diff hunk header left over from the patch view, not source code; it should be stripped rather than treated as part of the file.

  2. Docstring Improvements: Consider providing more detailed docstrings for each method to improve readability and understanding.

  3. Variable Names: Variable names like model_info could benefit from being more descriptive.

  4. Code Formatting: Ensure consistent formatting throughout the file, such as using proper indentation and spacing rules.

Here's an improved version of the code with these suggestions addressed:

from typing import Dict

from langchain.chains.summarize.retrieval_qa_reranking import RetrievalQARerankChain, Retriever, LLM
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISSStore  # Assuming FAISSStore is part of your VectorStores implementation
from django.utils.translation import gettext as _  # assuming Django i18n provides _()

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
# BaseModelCredential and ValidCode are assumed to come from the project's base model provider module

class VllmRerankerCredential(BaseForm, BaseModelCredential):
    api_url: str = forms.TextInputField('API URL', required=True)
    api_key: str = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: dict, model_params, provider=None, raise_exception=True) -> tuple:
        """
        Validates the VLLM Reranker credentials based on the model type and additional checks.

        :param model_type: Type of the model
        :param model_name: Name of the model
        :param model_credential: Dictionary containing necessary credentials
        :param model_params: Additional parameters for the model
        :param provider: Optional provider instance (not used in this example)
        :param raise_exception: Whether to raise an exception on failure
        :return: Tuple of (success flag, message)
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt['value'] == model_type for mt in model_type_list):
            self.log.error(f"Model type '{model_type}' is not supported")
            return False, "Unsupported model type"

        missing_fields = [k for k in ('api_url', 'api_key') if k not in model_credential]
        if missing_fields:
            error_message = ", ".join(missing_fields).capitalize() + " field(s) is/are required."
            self.log.error(error_message)
            return False, error_message

        try:
            # Create a dummy document and attempt to compress it
            faiss_store = FAISSStore.from_texts(["Hello"], None)  # Replace with actual vectorization step
            rerank_chain = RetrievalQARerankChain(
                retriever=faiss_store.as_retriever(),
                llm=LLM(prompt_template=PromptTemplate(input_variables=["question", "context"], template="...")),
                max_answer_length=model_params.get("max_answer_length", 50),
                temperature=model_params.get("temperature", 0.2)
            )
            rerank_chain({"input": {"question": _("Hello"), "context": _("This is a test.")}})
        except Exception as e:
            self.log.error(f"Verification failed: {e}")
            if isinstance(e, AppApiException):
                raise e
            elif raise_exception:
                err_code = ValidCode.valid_error.value
                error_reason = f"Verification failed, please check whether the parameters are correct: {str(e)}"
                self.log.error(error_reason)
                raise AppApiException(err_code, error_reason)
            return False, "Verification failed"

        return True, "Validation successful"

    def encryption_dict(self, model_info: Dict[str, object]) -> Dict[str, object]:
        """
        Encrypts sensitive information within the model info dictionary.

        :param model_info: A dictionary containing model details, including API keys
        :return: Encrypted model information without sensitive data
        """
        return {
            **model_info,
            "enc_api_key": super().encrypt(model_info.get("api_key", ""))
        }

Changes Made:

  • Removed the unused first comment.
  • Provided clearer docstrings for better understanding.
  • Used snake_case variable names instead of camelCase.
  • Added comments explaining each section of the code.
  • Improved error logging and exception handling by adding context logs.
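The (success flag, message) return pattern used in the is_valid sketch above can be distilled into a standalone, testable helper. This is an illustrative sketch only; the function name and the simplified checks are assumptions, not MaxKB's actual implementation:

```python
from typing import Tuple

# Field names mirror the credential form sketched above
REQUIRED_FIELDS = ("api_url", "api_key")


def validate_credential(model_credential: dict) -> Tuple[bool, str]:
    """Return (ok, message), mirroring the is_valid convention above."""
    missing = [k for k in REQUIRED_FIELDS if k not in model_credential]
    if missing:
        # Report every missing field in one message
        return False, ", ".join(missing) + " field(s) is/are required."
    return True, "Validation successful"
```

For example, `validate_credential({"api_url": "http://localhost:8000"})` reports that `api_key` is missing, while a dict with both fields validates successfully.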

    ds = [d.page_content for d in documents]
    result = self.client.rerank(model=self.model, query=query, documents=ds)
    return [Document(page_content=d.document.get('text'), metadata={'relevance_score': d.relevance_score})
            for d in result.results]
Contributor Author


The provided code appears to be a LangChain document compressor that integrates with a Cohere-compatible reranking API. While it looks functional, there are several areas where improvements can be made:

Improvements:

  1. Type Annotations: The BaseDocumentCompressor import is not being used in this class. If you intended to use it, ensure all necessary methods from this interface are implemented.

  2. Error Handling: Adding error handling for exceptions raised during HTTP requests and model inference could improve robustness.

  3. Logging: Consider adding logging statements to track the flow of data through the compress_documents method.

  4. Performance Optimization: For very large documents or queries, consider optimizing how documents are processed before sending them to the server. This might involve chunking long documents or performing partial re-rankings.

  5. Callback Support: Ensure that the callbacks parameter correctly passes through and works within the cohere.ClientV2 interactions.

  6. Testing: Add unit tests to cover various scenarios, such as null input/output, different document types, etc.

  7. Configuration Management: Provide better configuration mechanisms, especially regarding authentication keys and URLs, possibly allowing users to load configurations directly from YAML files or environment variables.

Here’s an updated version with some suggestions:

from typing import Sequence, Optional, Dict, Any

import cohere
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def create_cohere_client(api_key: str, api_url: str) -> cohere.ClientV2:
    # Recent cohere clients accept api_key/base_url; adjust if your client version differs
    return cohere.ClientV2(api_key=api_key, base_url=api_url)

MAX_DOCUMENTS = 100  # Limit number of documents sent per request to avoid rate limits


class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
    _client = None
    api_key: str
    api_url: str
    model: str
    params: dict

    def __init__(self, api_key: str, api_url: str, model: str, params: dict, callbacks: Callbacks = None):
        super().__init__()
        self.api_key = api_key
        self.model = model
        self.params = params
        self.api_url = api_url
        try:
            logger.info("Initializing reranker client for %s", self.api_url)  # avoid logging the API key
            self._client = create_cohere_client(api_key=self.api_key, api_url=self.api_url)
        except (TypeError, ValueError) as e:
            raise RuntimeError("Invalid credential format") from e

    @staticmethod
    def is_cache_model() -> bool:
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return VllmBgeReranker(
            model=model_name,
            api_key=model_credential['api_key'],
            api_url=model_credential['api_url'],
            params=model_kwargs  # pass once; also spreading **model_kwargs would duplicate arguments
        )

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None, batch_size: int = MAX_DOCUMENTS) -> Sequence[Document]:
        """
        Compresses the given documents based on relevancy to the query using Cohere's reranker.
        
        :param documents: A sequence of documents to compress.
        :param query: The search query string.
        :param callbacks: An optional list of callback functions.
        :return: List of compressed documents with relevance scores.
        """
        if documents is None or len(documents) == 0:
            return []

        # Chunk documents into batches to avoid hitting rate limits
        chunks = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
        
        responses = []
        for chunk in chunks:
            try:
                logger.debug("Sending %d documents for rerank", len(chunk))

                documents_texts = [doc.page_content for doc in chunk]
                result = self._client.rerank(
                    model=self.model,
                    query=query,
                    documents=documents_texts,
                    **self.params  # forward extra rerank options (e.g. top_n), if any
                )
                
                responses.extend(result.results)
            except Exception as e:
                logger.error("Failed to process batch of %d documents: %s", len(chunk), e)
                continue
        
        compressed_docs = [
            Document(page_content=result.document.get('text'), metadata={'relevance_score': result.relevance_score})
            for result in responses
        ]
        
        return compressed_docs

This updated function handles batching of inputs, adds basic error handling, and includes logging. It also allows specification of a batch_size, which helps reduce the impact of rate limits when dealing with large sets of documents.
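The batching step can be isolated from the network call so it is testable without a live reranking endpoint. A minimal sketch, with the function name chosen here for illustration and the default batch size mirroring the MAX_DOCUMENTS assumption above:

```python
from typing import List, Sequence


def chunk_documents(documents: Sequence, batch_size: int = 100) -> List[list]:
    """Split a sequence of documents into batches of at most batch_size items."""
    if not documents:
        return []
    # Slicing past the end is safe: the final batch simply holds the remainder
    return [list(documents[i:i + batch_size]) for i in range(0, len(documents), batch_size)]
```

With 250 documents and the default batch size, this yields three batches of 100, 100, and 50 items, so each rerank request stays under the per-request limit.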

@zhanweizhang7 zhanweizhang7 merged commit f5fada9 into v2 Aug 21, 2025
3 of 6 checks passed

    .append_default_model_info(reranker_model_info_list[0])
    .build()
)

Contributor Author


There are no significant irregularities or potential issues in this code. The module appears complete, and the different model types (LLMs, images, embeddings, STTs) are each handled with their respective credentials. Including both default and primary model information in the model_info_manage object is good practice. One small suggestion: if you plan to add more features or expand the codebase, consider grouping related functions into classes for better organization and encapsulation.
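The `.append_default_model_info(...).build()` fragment quoted above suggests a fluent builder. A hypothetical minimal sketch of that pattern, with class and method names assumed here rather than taken from MaxKB's real ModelInfoManage API:

```python
class ModelInfoManageBuilder:
    """Illustrative fluent builder mirroring the chained calls quoted above."""

    def __init__(self):
        self._model_infos = []
        self._default = None

    def append_model_info_list(self, infos):
        self._model_infos.extend(infos)
        return self  # returning self is what enables method chaining

    def append_default_model_info(self, info):
        self._default = info
        return self

    def build(self):
        # Freeze the accumulated state into a plain result object
        return {"model_infos": list(self._model_infos), "default": self._default}
```

Each append method returns the builder itself, so registration reads as one chained expression ending in `.build()`, matching the style of the fragment above.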

@zhanweizhang7 zhanweizhang7 deleted the pr@v2@feat_vllm_reranker branch August 21, 2025 10:02