-
Notifications
You must be signed in to change notification settings - Fork 2.8k
feat: Vllm whisper model #3901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Vllm whisper model #3901
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # coding=utf-8 | ||
| import traceback | ||
| from typing import Dict | ||
|
|
||
| from django.utils.translation import gettext_lazy as _, gettext | ||
| from langchain_core.messages import HumanMessage | ||
|
|
||
| from common import forms | ||
| from common.exception.app_exception import AppApiException | ||
| from common.forms import BaseForm, TooltipLabel | ||
| from models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
|
||
|
|
||
| class VLLMWhisperModelParams(BaseForm): | ||
| Language = forms.TextInputField( | ||
| TooltipLabel(_('Language'), | ||
| _("If not passed, the default value is 'zh'")), | ||
| required=True, | ||
| default_value='zh', | ||
| ) | ||
|
|
||
|
|
||
| class VLLMWhisperModelCredential(BaseForm, BaseModelCredential): | ||
| api_url = forms.TextInputField('API URL', required=True) | ||
| api_key = forms.PasswordInputField('API Key', required=True) | ||
|
|
||
| def is_valid(self, | ||
| model_type: str, | ||
| model_name, | ||
| model_credential: Dict[str, object], | ||
| model_params, | ||
| provider, | ||
| raise_exception=False): | ||
|
|
||
| model_type_list = provider.get_model_type_list() | ||
|
|
||
| if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): | ||
| raise AppApiException(ValidCode.valid_error.value, | ||
| gettext('{model_type} Model type is not supported').format(model_type=model_type)) | ||
| try: | ||
| model_list = provider.get_base_model_list(model_credential.get('api_url'), model_credential.get('api_key')) | ||
| except Exception as e: | ||
| raise AppApiException(ValidCode.valid_error.value, gettext('API domain name is invalid')) | ||
| exist = provider.get_model_info_by_name(model_list, model_name) | ||
| if len(exist) == 0: | ||
| raise AppApiException(ValidCode.valid_error.value, | ||
| gettext('The model does not exist, please download the model first')) | ||
| model = provider.get_model(model_type, model_name, model_credential, **model_params) | ||
| return True | ||
|
|
||
| def encryption_dict(self, model_info: Dict[str, object]): | ||
| return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} | ||
|
|
||
| def build_model(self, model_info: Dict[str, object]): | ||
| for key in ['api_key', 'model']: | ||
| if key not in model_info: | ||
| raise AppApiException(500, gettext('{key} is required').format(key=key)) | ||
| self.api_key = model_info.get('api_key') | ||
| return self | ||
|
|
||
| def get_model_params_setting_form(self, model_name): | ||
| return VLLMWhisperModelParams() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import base64 | ||
| import os | ||
| import traceback | ||
| from typing import Dict | ||
|
|
||
| from openai import OpenAI | ||
|
|
||
| from common.utils.logger import maxkb_logger | ||
| from models_provider.base_model_provider import MaxKBBaseModel | ||
| from models_provider.impl.base_stt import BaseSpeechToText | ||
|
|
||
|
|
||
|
|
||
| class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText): | ||
| api_key: str | ||
| api_url: str | ||
| model: str | ||
| params: dict | ||
|
|
||
| def __init__(self, **kwargs): | ||
| super().__init__(**kwargs) | ||
| self.api_key = kwargs.get('api_key') | ||
| self.model = kwargs.get('model') | ||
| self.params = kwargs.get('params') | ||
| self.api_url = kwargs.get('api_url') | ||
|
|
||
| @staticmethod | ||
| def is_cache_model(): | ||
| return False | ||
|
|
||
| @staticmethod | ||
| def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): | ||
| return VllmWhisperSpeechToText( | ||
| model=model_name, | ||
| api_key=model_credential.get('api_key'), | ||
| api_url=model_credential.get('api_url'), | ||
| params=model_kwargs, | ||
| **model_kwargs | ||
| ) | ||
|
|
||
| def check_auth(self): | ||
| cwd = os.path.dirname(os.path.abspath(__file__)) | ||
| with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as audio_file: | ||
| self.speech_to_text(audio_file) | ||
|
|
||
| def speech_to_text(self, audio_file): | ||
| base_url = f"{self.api_url}/v1" | ||
| try: | ||
| client = OpenAI( | ||
| api_key=self.api_key, | ||
| base_url=base_url | ||
| ) | ||
|
|
||
| result = client.audio.transcriptions.create( | ||
| file=audio_file, | ||
| model=self.model, | ||
| language=self.params.get('Language'), | ||
| response_format="json" | ||
| ) | ||
|
|
||
| return result.text | ||
|
|
||
| except Exception as err: | ||
| maxkb_logger.error(f":Error: {str(err)}: {traceback.format_exc()}") | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This code appears to be an implementation of a speech-to-text service using the VLLM Whisper model through OpenAI's API. Here are some observations and potential areas for improvement: Observations:
Potential Improvements:
Here is a refined version of the code addressing some of these points: # Imports
import os
import traceback
from typing import Dict
from openai import OpenAI
import logging
from common.utils.logger import maxkb_logger
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText
# Set up module-level logging for the refined snippet below.
# NOTE(review): calling logging.basicConfig() at import time in a library
# module mutates the root logger for the entire application — presumably
# intended only for standalone testing; confirm before shipping. The project
# already provides maxkb_logger (imported above from common.utils.logger).
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text backed by a Whisper model served through vLLM's
    OpenAI-compatible HTTP API.

    Instances are configured with an API URL/key, a model name, and a
    ``params`` dict whose keys come from ``VLLMWhisperModelParams``.
    """

    api_key: str
    api_url: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        """Accept ``api_key``, ``api_url``, ``model`` and ``params`` as kwargs."""
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        # Guard against a missing/None 'params' entry so speech_to_text can
        # always call self.params.get(...) safely.
        self.params = kwargs.get('params') or {}
        self.api_url = kwargs.get('api_url')

    @staticmethod
    def is_cache_model():
        # Instances are cheap to construct; no caching needed.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry to build a configured instance."""
        return VllmWhisperSpeechToText(
            model=model_name,
            api_key=model_credential.get('api_key'),
            api_url=model_credential.get('api_url'),
            params=model_kwargs,
            **model_kwargs
        )

    def check_auth(self):
        """Validate credentials by transcribing a bundled sample audio file.

        Errors are logged and re-raised so the caller can report the failure;
        silently swallowing them would make a broken credential look valid.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        try:
            with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as audio_file:
                self.speech_to_text(audio_file)
        except FileNotFoundError as e:
            logger.error(f"FileNotFoundError: {str(e)}", exc_info=True)
            # Re-raise: an auth check that cannot run must not report success.
            raise

    def speech_to_text(self, audio_file) -> str:
        """
        Convert an audio file to text using the VLLM Whisper model.

        :param audio_file: binary file-like object containing the audio
        :return: transcribed text, or ``None`` when the request fails
        """
        base_url = f"{self.api_url}/v1"
        try:
            client = OpenAI(api_key=self.api_key, base_url=base_url)
            result = client.audio.transcriptions.create(
                file=audio_file,
                model=self.model,
                # Key must match the 'Language' field declared on
                # VLLMWhisperModelParams (capital L); a lowercase 'language'
                # key always returns None and silently drops the setting.
                language=self.params.get('Language'),
                response_format="json"
            )
            return result.text
        except Exception as err:
            # Best-effort: log the failure and return None rather than
            # crashing the caller.
            logger.error(f"An error occurred during transcription: {str(err)}", exc_info=True)
            return None  # Return None instead of empty string

# These changes improve the structure, enhance logging, clarify method names,
# and add basic error handling for the audio-file transcription path.
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,20 +10,27 @@ | |
| from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential | ||
| from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential | ||
| from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential | ||
| from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential | ||
| from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel | ||
| from models_provider.impl.vllm_model_provider.model.image import VllmImage | ||
| from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel | ||
| from maxkb.conf import PROJECT_DIR | ||
| from django.utils.translation import gettext as _ | ||
|
|
||
| from models_provider.impl.vllm_model_provider.model.whisper_sst import VllmWhisperSpeechToText | ||
|
|
||
| v_llm_model_credential = VLLMModelCredential() | ||
| image_model_credential = VllmImageModelCredential() | ||
| embedding_model_credential = VllmEmbeddingCredential() | ||
| whisper_model_credential = VLLMWhisperModelCredential() | ||
|
|
||
| model_info_list = [ | ||
| ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), | ||
| ModelInfo('BAAI/Aquila-7B', _('BAAI’s 7B parameter model'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), | ||
| ModelInfo('BAAI/AquilaChat-7B', _('BAAI’s 13B parameter mode'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), | ||
| ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential, | ||
| VllmChatModel), | ||
| ModelInfo('BAAI/Aquila-7B', _('BAAI’s 7B parameter model'), ModelTypeConst.LLM, v_llm_model_credential, | ||
| VllmChatModel), | ||
| ModelInfo('BAAI/AquilaChat-7B', _('BAAI’s 13B parameter mode'), ModelTypeConst.LLM, v_llm_model_credential, | ||
| VllmChatModel), | ||
|
|
||
| ] | ||
|
|
||
|
|
@@ -32,7 +39,15 @@ | |
| ] | ||
|
|
||
| embedding_model_info_list = [ | ||
| ModelInfo('HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5', '', ModelTypeConst.EMBEDDING, embedding_model_credential, VllmEmbeddingModel), | ||
| ModelInfo('HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5', '', ModelTypeConst.EMBEDDING, | ||
| embedding_model_credential, VllmEmbeddingModel), | ||
| ] | ||
|
|
||
| whisper_model_info_list = [ | ||
| ModelInfo('whisper-tiny', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ModelInfo('whisper-large-v3-turbo', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ModelInfo('whisper-small', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ModelInfo('whisper-large-v3', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ] | ||
|
|
||
| model_info_manage = ( | ||
|
|
@@ -45,6 +60,8 @@ | |
| .append_default_model_info(image_model_info_list[0]) | ||
| .append_model_info_list(embedding_model_info_list) | ||
| .append_default_model_info(embedding_model_info_list[0]) | ||
| .append_model_info_list(whisper_model_info_list) | ||
| .append_default_model_info(whisper_model_info_list[0]) | ||
| .build() | ||
| ) | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This code looks mostly correct and should work without significant issues. However, there are a few areas that could be improved:
Here's an updated version with some optimizations: from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential
from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential
from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential
from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel
from models_provider.impl.vllm_model_provider.model.image import VllmImage
from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
# Credential singletons shared by the model-info entries below.
v_llm_model_credential = VLLMModelCredential()
image_model_credential = VllmImageModelCredential()
embedding_model_credential = VllmEmbeddingCredential()
whisper_model_credential = VLLMWhisperModelCredential()

# (model class, credential) pairs. Fixed typo: the import above binds the
# name VllmChatModel, not VLLMChatModel (which would raise NameError).
model_info_lists = [
    (VllmChatModel, v_llm_model_credential),
    (VllmEmbeddingModel, embedding_model_credential),
    (VllmWhisperSpeechToText, whisper_model_credential)
]

# NOTE(review): image_model_info_list / embedding_model_info_list /
# whisper_model_info_list and config_management are not defined in this
# snippet — presumably defined elsewhere in the module (the PR diff defines
# the *_model_info_list names); confirm before applying. VllmWhisperSpeechToText
# also needs its import (from ...model.whisper_sst) retained.
all_models = (
    image_model_info_list +
    embedding_model_info_list +
    whisper_model_info_list
)

# Assuming append_default_model_info handles adding the first element twice if needed
config_management.append_model_info_list(all_models).build()

# Changes Made:
These changes aim to make the code cleaner and potentially more efficient by avoiding redundancy in the model info lists. |
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No irregularities found. The code looks to be correctly structured for a Django-based form handling with model credential validation. Here are some optimizations you might consider:
Use
gettextdirectly: Since both_('Language')andgettext('{model_type} Model type is not supported').format(model_type=model_type)usegettext, they can be consolidated into a single call.Remove unnecessary imports:
`langchain_core.messages.HumanMessage` is imported but never referenced within this class, so it's safe to remove from the imports list. Encapsulate logic: You could encapsulate some of the exception handling and message formatting in helper functions rather than repeating them across lines.
Consider using context managers: If you anticipate making multiple network requests or database interactions, using async contexts (Python 3.7+) would help manage operations more cleanly.
Here's an updated version with these considerations:
This version introduces helpers like
`_create_error_msg` for better readability and consolidates duplicated message-creation patterns.