-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Expand file tree
/
Copy pathwhisper_sst.py
More file actions
69 lines (55 loc) · 2.06 KB
/
whisper_sst.py
File metadata and controls
69 lines (55 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import base64
import os
import traceback
from typing import Dict
from openai import OpenAI
from common.utils.logger import maxkb_logger
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText
class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that sends audio to a transcription endpoint.

    Requests go through the ``OpenAI`` client, pointed at ``api_url`` (with a
    ``/v1`` suffix appended when it is missing). The transcription language
    is fixed to ``'zh'``.
    """

    # Credential handed to the OpenAI client.
    api_key: str
    # Server address; "/v1" is appended when it does not already end in "v1".
    api_url: str
    # Model name sent with every transcription request.
    model: str
    # Extra request options; forwarded through ``extra_body`` after filtering.
    params: dict

    def __init__(self, **kwargs):
        """Initialise the base classes, then copy the settings from ``kwargs``.

        Reads the ``api_key``, ``model``, ``params`` and ``api_url`` keys;
        a missing key leaves the attribute as ``None``.
        """
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')
        self.api_url = kwargs.get('api_url')

    @staticmethod
    def is_cache_model():
        """Return ``False``.

        NOTE(review): presumably tells the framework not to cache instances
        of this model -- confirm against ``MaxKBBaseModel``.
        """
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a model name and a credential mapping.

        Args:
            model_type: Not used by this implementation.
            model_name: Stored as the model sent with each request.
            model_credential: Mapping read for ``api_key`` and ``api_url``.
            **model_kwargs: Stored as ``params`` and also forwarded to the
                constructor as keyword arguments.
        """
        return VllmWhisperSpeechToText(
            model=model_name,
            api_key=model_credential.get('api_key'),
            api_url=model_credential.get('api_url'),
            params=model_kwargs,
            **model_kwargs
        )

    def _base_url(self):
        """Return ``api_url`` normalised so that it ends with ``v1``."""
        # Strip trailing slashes first: without this, "http://host/" became
        # "http://host//v1" and "http://host/v1/" became "http://host/v1//v1".
        url = self.api_url.rstrip('/')
        return url if url.endswith('v1') else f"{url}/v1"

    def _transcribe(self, audio_file, base_url):
        """Send ``audio_file`` for transcription and return the text.

        Unlike ``speech_to_text`` this helper lets every exception propagate,
        so callers can decide whether a failure should be reported or logged.
        """
        client = OpenAI(
            api_key=self.api_key,
            base_url=base_url
        )
        buf = audio_file.read()
        # These keys are dropped before forwarding; everything else in
        # ``params`` is sent to the server via ``extra_body``.
        filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
        transcription_params = {
            'model': self.model,
            'file': buf,
            'language': 'zh',
        }
        result = client.audio.transcriptions.create(
            **transcription_params, extra_body=filter_params
        )
        return result.text

    def check_auth(self):
        """Verify the endpoint and credentials with a real transcription.

        Sends the sample file ``iat_mp3_16k.mp3`` located next to this module.

        Raises:
            Exception: Whatever opening the sample file or the transcription
                request raises. Failures are deliberately NOT swallowed here;
                routing this check through ``speech_to_text`` (which logs and
                returns ``None``) made it succeed even with bad credentials.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as audio_file:
            self._transcribe(audio_file, self._base_url())

    def speech_to_text(self, audio_file):
        """Transcribe ``audio_file`` and return the recognised text.

        Args:
            audio_file: Object with a ``read()`` method; its content is
                uploaded as the audio payload.

        Returns:
            The transcription text, or ``None`` when the request fails (the
            error and traceback are written to ``maxkb_logger``).
        """
        # Resolved outside the ``try`` so that a missing ``api_url`` still
        # raises instead of being logged, as before.
        base_url = self._base_url()
        try:
            return self._transcribe(audio_file, base_url)
        except Exception as err:
            maxkb_logger.error(f":Error: {str(err)}: {traceback.format_exc()}")
            return None