-
Notifications
You must be signed in to change notification settings - Fork 2.8k
feat: Vllm reranker model bge reranker v2 m3 #3909
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import traceback | ||
| from typing import Dict | ||
|
|
||
| from langchain_core.documents import Document | ||
|
|
||
| from common import forms | ||
| from common.exception.app_exception import AppApiException | ||
| from common.forms import BaseForm | ||
| from models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
| from django.utils.translation import gettext_lazy as _ | ||
|
|
||
| from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker | ||
|
|
||
|
|
||
class VllmRerankerCredential(BaseForm, BaseModelCredential):
    """Credential form for the vLLM reranker model: an API URL plus an API key."""

    api_url = forms.TextInputField('API URL', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=True):
        """Validate the credential by issuing a tiny rerank request.

        Returns True when the credential works; otherwise returns False, or
        raises AppApiException when raise_exception is set.
        """
        supported_types = provider.get_model_type_list()
        # The requested model type must be one the provider declares.
        # NOTE: this check raises regardless of raise_exception, matching the
        # other credential classes in this provider.
        if not any(entry.get('value') == model_type for entry in supported_types):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        # Both connection fields are mandatory.
        for required_key in ['api_url', 'api_key']:
            if required_key in model_credential:
                continue
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=required_key))
            return False

        try:
            # Smoke-test the credential with a single-document rerank call.
            reranker: VllmBgeReranker = provider.get_model(model_type, model_name, model_credential)
            reranker.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
        except Exception as e:
            traceback.print_exc()
            # Application-level errors propagate unchanged.
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    _('Verification failed, please check whether the parameters are correct: {error}').format(
                        error=str(e))
                )
            return False

        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return model_info with the api_key value encrypted for storage/display."""
        masked_key = super().encryption(model_info.get('api_key', ''))
        return {**model_info, 'api_key': masked_key}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| from typing import Sequence, Optional, Dict, Any | ||
|
|
||
| import cohere | ||
| from langchain_core.callbacks import Callbacks | ||
| from langchain_core.documents import BaseDocumentCompressor, Document | ||
|
|
||
| from models_provider.base_model_provider import MaxKBBaseModel | ||
|
|
||
|
|
||
class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Reranker backed by a vLLM server exposing a Cohere-compatible rerank API."""

    api_key: str
    api_url: str
    model: str
    params: dict
    client: Any = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror the credential fields onto the instance, then build the client
        # pointed at the vLLM server's base URL.
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')
        self.api_url = kwargs.get('api_url')
        self.client = cohere.ClientV2(kwargs.get('api_key'), base_url=kwargs.get('api_url'))

    @staticmethod
    def is_cache_model():
        # Instances hold a live HTTP client, so they are not cached.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry to build a configured reranker."""
        return VllmBgeReranker(
            model=model_name,
            api_key=model_credential.get('api_key'),
            api_url=model_credential.get('api_url'),
            params=model_kwargs,
            **model_kwargs
        )

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Rerank *documents* against *query*, attaching each relevance score
        to the returned Document metadata."""
        if not documents:
            return []

        texts = [doc.page_content for doc in documents]
        response = self.client.rerank(model=self.model, query=query, documents=texts)
        reranked = []
        for item in response.results:
            reranked.append(Document(page_content=item.document.get('text'),
                                     metadata={'relevance_score': item.relevance_score}))
        return reranked
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The provided code is a LangChain document compressor that integrates with Cohere's Reranking API. While it looks functional, there are several areas where improvements can be made:
Here’s an updated version with some suggestions: from typing import Sequence, Optional, Dict, Any
import cohere
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
import logging
logging.basicConfig(level=logging.INFO)
def create_cohere_client(api_key: str, api_url: str) -> cohere.ClientV2:
return cohere.ClientV2(authentication_token=api_key, base_url=api_url)
MAX_DOCUMENTS = 100 # Limit number of documents sent per request to avoid rate limits
class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
_client = None
api_key: str
api_url: str
model: str
params: dict
def __init__(self, api_key: str, api_url: str, model: str, params: dict, callbacks: Callbacks = None):
super().__init__()
self.api_key = api_key
self.model = model
self.params = params
self.api_url = api_url
try:
logger.info(f"Initializing client with {self.api_key}")
self._client = create_cohere_client(api_key=self.api_key, api_url=self.api_url)
except (TypeError, ValueError) as e:
raise RuntimeError("Invalid credential format", e)
@staticmethod
def is_cache_model() -> bool:
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return VllmBgeReranker(
model=model_name,
api_key=model_credential['api_key'],
api_url=model_credential['api_url'],
params=model_kwargs,
**model_kwargs
)
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None, batch_size: int = MAX_DOCUMENTS) -> Sequence[Document]:
"""
Compresses the given documents based on relevancy to the query using Cohere's reranker.
:param documents: A sequence of documents to compress.
:param query: The search query string.
:param callbacks: An optional list of callback functions.
:return: List of compressed documents with relevance scores.
"""
if documents is None or len(documents) == 0:
return []
# Chunk documents into batches to avoid hitting rate limits
chunks = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
responses = []
for chunk in chunks:
try:
log_info = f"Sending {(len(chunk))} documents for rerank"
logger.debug(log_info)
documents_texts = [doc.page_content for doc in chunk]
result = self._client.rerank(
model=self.model,
query=query,
documents=documents_texts,
custom_parameters=self.params
)
responses.extend(result.results)
except Exception as e:
error_message = f"Failed to process batch {chunk}: {str(e)}"
logger.error(error_message)
continue
compressed_docs = [
Document(page_content=result.document.get('text'), metadata={'relevance_score': result.relevance_score})
for result in responses
]
    return compressed_docs

This updated function handles batching of inputs, adds basic error handling, and includes logging. It also allows specification of a |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,19 +10,22 @@ | |
| from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential | ||
| from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential | ||
| from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential | ||
| from models_provider.impl.vllm_model_provider.credential.reranker import VllmRerankerCredential | ||
| from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential | ||
| from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel | ||
| from models_provider.impl.vllm_model_provider.model.image import VllmImage | ||
| from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel | ||
| from maxkb.conf import PROJECT_DIR | ||
| from django.utils.translation import gettext as _ | ||
|
|
||
| from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker | ||
| from models_provider.impl.vllm_model_provider.model.whisper_sst import VllmWhisperSpeechToText | ||
|
|
||
| v_llm_model_credential = VLLMModelCredential() | ||
| image_model_credential = VllmImageModelCredential() | ||
| embedding_model_credential = VllmEmbeddingCredential() | ||
| whisper_model_credential = VLLMWhisperModelCredential() | ||
| rerank_model_credential = VllmRerankerCredential() | ||
|
|
||
| model_info_list = [ | ||
| ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential, | ||
|
|
@@ -50,6 +53,10 @@ | |
| ModelInfo('whisper-large-v3', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ] | ||
|
|
||
| reranker_model_info_list = [ | ||
| ModelInfo('bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, rerank_model_credential, VllmBgeReranker), | ||
| ] | ||
|
|
||
| model_info_manage = ( | ||
| ModelInfoManage.builder() | ||
| .append_model_info_list(model_info_list) | ||
|
|
@@ -62,6 +69,8 @@ | |
| .append_default_model_info(embedding_model_info_list[0]) | ||
| .append_model_info_list(whisper_model_info_list) | ||
| .append_default_model_info(whisper_model_info_list[0]) | ||
| .append_model_info_list(reranker_model_info_list) | ||
| .append_default_model_info(reranker_model_info_list[0]) | ||
| .build() | ||
| ) | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There are no significant irregularities or potential issues in this code. The structure of the module appears to be complete, and there is adequate handling of different types of models (LLMs, images, embeddings, STTs) with their respective credentials. Additionally, it looks like a good practice to include both default and primary model information in the |
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided code looks generally good, but there are a few areas where improvements can be made:
Unused Import: The commented-out `@@ -0,0 +1,50 @@` line indicates an unused placeholder that should be removed.

Docstring Improvements: Consider providing more detailed docstrings for each method to improve readability and understanding.
Variable Names: Variable names like `model_info` could benefit from being more descriptive.

Code Formatting: Ensure consistent formatting throughout the file, such as using proper indentation and spacing rules.
Here's an improved version of the code with these suggestions addressed:
Changes Made: