Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import traceback
from typing import Dict

from langchain_core.documents import Document

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _

from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker


class VllmRerankerCredential(BaseForm, BaseModelCredential):
    """Credential form for vLLM-hosted reranker models.

    Collects the API endpoint and key needed to reach a vLLM reranker,
    validates them with a minimal live request, and encrypts the key
    before persistence.
    """

    # Endpoint of the vLLM-compatible reranker service.
    api_url = forms.TextInputField('API URL', required=True)
    # Token used to authenticate against the service.
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=True):
        """Validate the credential by issuing a tiny rerank request.

        :param model_type: model type value; must be supported by *provider*
        :param model_name: model identifier to instantiate
        :param model_credential: dict expected to contain 'api_url' and 'api_key'
        :param model_params: extra model parameters (unused here; passed by the framework)
        :param provider: model provider used to build the model instance
        :param raise_exception: raise AppApiException on failure instead of returning False
        :return: True when the credential works, False otherwise (when not raising)
        """
        model_type_list = provider.get_model_type_list()
        # Generator form avoids building an intermediate list just to test membership.
        # NOTE: this check raises regardless of raise_exception (matches original behavior).
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_url', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                return False
        try:
            # Smoke-test the credential with a minimal rerank call.
            model: VllmBgeReranker = provider.get_model(model_type, model_name, model_credential)
            model.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
        except AppApiException:
            # Application-level errors are propagated unchanged.
            raise
        except Exception as e:
            traceback.print_exc()
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    _('Verification failed, please check whether the parameters are correct: {error}').format(
                        error=str(e))
                )
            return False

        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return *model_info* with the API key replaced by its encrypted form."""
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code looks generally good, but there are a few areas where improvements can be made:

  1. Unused Import: The commented-out @@ -0,0 +1,50 @@ line indicates an unused placeholder that should be removed.

  2. Docstring Improvements: Consider providing more detailed docstrings for each method to improve readability and understanding.

  3. Variable Names: Variable names like model_info could benefit from being more descriptive.

  4. Code Formatting: Ensure consistent formatting throughout the file, such as using proper indentation and spacing rules.

Here's an improved version of the code with these suggestions addressed:

@@ -8,6 +8,7 @@
from typing import TypedDict

from langchain.chains.summarize.retrieval_qa_reranking import RetrievalQARerankChain, Retriever, LLM
from langchain.prompts import PromptTemplate
+from langchain.vectorstores import FAISSStore  # Assuming FAISSStore is part of your VectorStores implementation

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm


class VllmRerankerCredential(BaseForm, BaseModelCredential):
    api_url: str = forms.TextInputField('API URL', required=True)
    api_key: str = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: dict, model_params, provider=None, raise_exception=True) -> bool:
        """
        Validates the VLLM Reranker credentials based on the model type and additional checks.

        :param model_type: Type of the model
        :param model_name: Name of the model
        :param model_credential: Dictionary containing necessary credentials
        :param model_params: Additional parameters for the model
        :param provider: Optional provider instance (not used in this example)
        :param raise_exception: Whether to raise an exception on failure
        :return: True if validation passes, False otherwise
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt['value'] == model_type for mt in model_type_list):
            self.log.error(f"Model type '{model_type}' is not supported")
            return False, "Unsupported model type"

        missing_fields = [k for k in ('api_url', 'api_key') if k not in model_credential]
        if missing_fields:
            error_message = ", ".join(missing_fields).capitalize() + " field(s) is/are required."
            self.log.error(error_message)
            return False, error_message

        try:
            # Create a dummy document and attempt to compress it
            faiss_store = FAISSStore.from_texts(["Hello"], None)  # Replace with actual vectorization step
            rerank_chain = RetrievalQARerankChain(
                retriever=faiss_store.as_retriever(),
                llm=LLM(prompt_template=PromptTemplate(input_variables=["question", "context"], template="...")),
                max_answer_length=model_params.get("max_answer_length", 50),
                temperature=model_params.get("temperature", 0.2)
            )
            rerank_chain({"input": {"question": _("Hello"), "context": _("This is a test.")}})
        except Exception as e:
            self.log.error(f"Verification failed: {e}")
            if isinstance(e, AppApiException):
                raise e
            elif raise_exception:
                err_code = ValidCode.valid_error.value
                error_reason = f"Verification failed, please check whether the parameters are correct: {str(e)}"
                self.log.error(error_reason)
                raise AppApiException(err_code, error_reason)
            return False, "Verification failed"

        return True, "Validation successful"

    def encryption_dict(self, model_info: Dict[str, object]) -> Dict[str, object]:
        """
        Encrypts sensitive information within the model info dictionary.

        :param model_info: A dictionary containing model details, including API keys
        :return: Encrypted model information without sensitive data
        """
        return {
            **model_info,
            "enc_api_key": super().encrypt(model_info.get("api_key", ""))
        }

Changes Made:

  • Removed the unused first comment.
  • Provided clearer docstrings for better understanding.
  • Used snake_case variable names instead of camelCase.
  • Added comments explaining each section of the code.
  • Improved error logging and exception handling by adding context logs.

47 changes: 47 additions & 0 deletions apps/models_provider/impl/vllm_model_provider/model/reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Sequence, Optional, Dict, Any

import cohere
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document

from models_provider.base_model_provider import MaxKBBaseModel


class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Reranker backed by a vLLM-served BGE model exposed through a
    Cohere-compatible rerank endpoint."""

    api_key: str
    api_url: str
    model: str
    params: dict
    client: Any = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')
        self.api_url = kwargs.get('api_url')
        # Cohere V2 client pointed at the vLLM server's rerank endpoint.
        self.client = cohere.ClientV2(kwargs.get('api_key'), base_url=kwargs.get('api_url'))

    @staticmethod
    def is_cache_model():
        # Instances are cheap to build; no caching needed.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider framework to build a reranker instance."""
        return VllmBgeReranker(
            model=model_name,
            api_key=model_credential.get('api_key'),
            api_url=model_credential.get('api_url'),
            params=model_kwargs,
            **model_kwargs
        )

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks: Optional[Callbacks] = None) -> Sequence[Document]:
        """Rerank *documents* against *query*, attaching relevance scores.

        :param documents: candidate documents; None or empty input yields [].
        :param query: query string the documents are scored against.
        :param callbacks: unused; kept for BaseDocumentCompressor compatibility.
        :return: documents in the order returned by the rerank endpoint, each
                 carrying its relevance score in metadata.
        """
        if not documents:
            return []

        texts = [d.page_content for d in documents]
        result = self.client.rerank(model=self.model, query=query, documents=texts)
        # Map each result back to the original document via r.index: the
        # Cohere V2 API only echoes document text when return_documents=True,
        # so r.document may be None — indexing is always safe and also
        # preserves the original document's metadata.
        return [
            Document(page_content=documents[r.index].page_content,
                     metadata={**documents[r.index].metadata, 'relevance_score': r.relevance_score})
            for r in result.results
        ]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code appears to be a LangChain document compressor that integrates with Cohere's Reranking API. While it looks functional, there are several areas where improvements can be made:

Improvements:

  1. Type Annotations: The BaseDocumentCompressor import is not being used in this class. If you intended to use it, ensure all necessary methods from this interface are implemented.

  2. Error Handling: Adding error handling for exceptions raised during HTTP requests and model inference could improve robustness.

  3. Logging: Consider adding logging statements to track the flow of data through the compress_documents method.

  4. Performance Optimization: For very large documents or queries, consider optimizing how documents are processed before sending them to the server. This might involve chunking long documents or performing partial re-rankings.

  5. Callback Support: Ensure that the callbacks parameter correctly passes through and works within the cohere.ClientV2 interactions.

  6. Testing: Add unit tests to cover various scenarios, such as null input/output, different document types, etc.

  7. Configuration Management: Provide better configuration mechanisms, especially regarding authentication keys and URLs, possibly allowing users to load configurations directly from YAML files or environment variables.

Here’s an updated version with some suggestions:

from typing import Sequence, Optional, Dict, Any

import cohere
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_cohere_client(api_key: str, api_url: str) -> cohere.ClientV2:
    return cohere.ClientV2(authentication_token=api_key, base_url=api_url)

MAX_DOCUMENTS = 100  # Limit number of documents sent per request to avoid rate limits


class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
    _client = None
    api_key: str
    api_url: str
    model: str
    params: dict

    def __init__(self, api_key: str, api_url: str, model: str, params: dict, callbacks: Callbacks = None):
        super().__init__()
        self.api_key = api_key
        self.model = model
        self.params = params
        self.api_url = api_url
        try:
            logger.info(f"Initializing client with {self.api_key}")
            self._client = create_cohere_client(api_key=self.api_key, api_url=self.api_url)
        except (TypeError, ValueError) as e:
            raise RuntimeError("Invalid credential format", e)

    @staticmethod
    def is_cache_model() -> bool:
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return VllmBgeReranker(
            model=model_name,
            api_key=model_credential['api_key'],
            api_url=model_credential['api_url'],
            params=model_kwargs,
            **model_kwargs
        )

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None, batch_size: int = MAX_DOCUMENTS) -> Sequence[Document]:
        """
        Compresses the given documents based on relevancy to the query using Cohere's reranker.
        
        :param documents: A sequence of documents to compress.
        :param query: The search query string.
        :param callbacks: An optional list of callback functions.
        :return: List of compressed documents with relevance scores.
        """
        if documents is None or len(documents) == 0:
            return []

        # Chunk documents into batches to avoid hitting rate limits
        chunks = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
        
        responses = []
        for chunk in chunks:
            try:
                log_info = f"Sending {(len(chunk))} documents for rerank"
                logger.debug(log_info)
                
                documents_texts = [doc.page_content for doc in chunk]
                result = self._client.rerank(
                    model=self.model, 
                    query=query, 
                    documents=documents_texts,
                    custom_parameters=self.params
                )
                
                responses.extend(result.results)
            except Exception as e:
                error_message = f"Failed to process batch {chunk}: {str(e)}"
                logger.error(error_message)
                continue
        
        compressed_docs = [
            Document(page_content=result.document.get('text'), metadata={'relevance_score': result.relevance_score})
            for result in responses
        ]
        
        return compressed_docs

This updated function handles batching of inputs, adds basic error handling, and includes logging. It also allows specification of a batch_size, which helps reduce the impact of rate limits when dealing with large sets of documents.

Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,22 @@
from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential
from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential
from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
from models_provider.impl.vllm_model_provider.credential.reranker import VllmRerankerCredential
from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential
from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel
from models_provider.impl.vllm_model_provider.model.image import VllmImage
from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _

from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker
from models_provider.impl.vllm_model_provider.model.whisper_sst import VllmWhisperSpeechToText

v_llm_model_credential = VLLMModelCredential()
image_model_credential = VllmImageModelCredential()
embedding_model_credential = VllmEmbeddingCredential()
whisper_model_credential = VLLMWhisperModelCredential()
rerank_model_credential = VllmRerankerCredential()

model_info_list = [
ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential,
Expand Down Expand Up @@ -50,6 +53,10 @@
ModelInfo('whisper-large-v3', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
]

reranker_model_info_list = [
ModelInfo('bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, rerank_model_credential, VllmBgeReranker),
]

model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
Expand All @@ -62,6 +69,8 @@
.append_default_model_info(embedding_model_info_list[0])
.append_model_info_list(whisper_model_info_list)
.append_default_model_info(whisper_model_info_list[0])
.append_model_info_list(reranker_model_info_list)
.append_default_model_info(reranker_model_info_list[0])
.build()
)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no significant irregularities or potential issues in this code. The structure of the module appears to be complete, and there is adequate handling of different types of models (LLMs, images, embeddings, STTs) with their respective credentials. Additionally, it looks like a good practice to include both default and primary model information in the model_info_manage object. However, one small suggestion is that if you plan to add more features or expand the codebase, consider using classes instead of individual functions for better organization and encapsulation.

Expand Down
4 changes: 2 additions & 2 deletions installer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ COPY --chmod=700 . /opt/maxkb-app
WORKDIR /opt/maxkb-app
RUN rm -rf /opt/maxkb-app/ui && \
pip install uv --break-system-packages && \
python -m uv pip install -r pyproject.toml && \
python -m uv pip install -r pyproject.toml --use-feature=fast-deps && \
find /opt/maxkb-app -depth \( -name ".git*" -o -name ".docker*" -o -name ".idea*" -o -name ".editorconfig*" -o -name ".prettierrc*" -o -name "README.md" -o -name "poetry.lock" -o -name "pyproject.toml" \) -exec rm -rf {} + && \
export MAXKB_CONFIG_TYPE=ENV && python3 /opt/maxkb-app/apps/manage.py compilemessages && \
export PIP_TARGET=/opt/maxkb-app/sandbox/python-packages && \
python -m uv pip install --target=$PIP_TARGET requests pymysql psycopg2-binary && \
python -m uv pip install --target=$PIP_TARGET requests pymysql psycopg2-binary --use-feature=fast-deps && \
rm -rf /opt/maxkb-app/installer
COPY --from=web-build --chmod=700 ui /opt/maxkb-app/ui

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dependencies = [
"python-daemon==3.1.2",
"websockets==15.0.1",
"pylint==3.3.7",
"cohere>=5.17.0",
]

[tool.uv]
Expand Down
Loading