Skip to content

Commit 666065e

Browse files
committed
feat: implement Volcanic Engine Big Model STT client and related classes
1 parent a79f887 commit 666065e

4 files changed

Lines changed: 306 additions & 3 deletions

File tree

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tti.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
6262
]
6363
)
6464
rsp = ImageGeneration.call(
65-
model="z-image-turbo",
65+
model=self.model_name,
6666
api_key=self.api_key,
6767
base_url=self.api_base,
6868
messages=[message],
@@ -71,8 +71,8 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
7171
)
7272
file_urls = []
7373
if rsp.status_code == HTTPStatus.OK:
74-
for result in rsp.output.results:
75-
file_urls.append(result.url)
74+
for result in rsp.output.choices:
75+
file_urls.append(result.message.content[0].get('image'))
7676
else:
7777
maxkb_logger.error('sync_call Failed, status_code: %s, code: %s, message: %s' %
7878
(rsp.status_code, rsp.code, rsp.message))
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# coding=utf-8
2+
from typing import Dict
3+
4+
from django.utils.translation import gettext as _
5+
6+
from common import forms
7+
from common.exception.app_exception import AppApiException
8+
from common.forms import BaseForm, TooltipLabel
9+
from models_provider.base_model_provider import BaseModelCredential, ValidCode
10+
from common.utils.logger import maxkb_logger
11+
12+
13+
class VolcanicEngineBigModelSTTModelParams(BaseForm):
    """Runtime parameter form for the Volcanic Engine big-model STT model."""

    # Caller identifier forwarded to the ASR service; the tooltip documents
    # that the service falls back to 'streaming_asr_demo' when omitted.
    uid = forms.TextInputField(
        TooltipLabel(_('User ID'), _('If not passed, the default value is streaming_asr_demo')),
        required=True,
        default_value='streaming_asr_demo'
    )
19+
20+
21+
class VolcanicEngineBigModelSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the Volcanic Engine big-model speech-to-text service."""

    volcanic_app_id = forms.TextInputField('App ID', required=True)
    volcanic_token = forms.PasswordInputField('Access Token', required=True)
    volcanic_api_url = forms.TextInputField('API URL', required=True,
                                            default_value='https://openspeech.bytedance.com/api/v3/auc/bigmodel/recognize/flash')

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential by instantiating the model and probing its auth.

        Returns True on success. On failure, returns False or raises
        AppApiException depending on `raise_exception`; an AppApiException
        raised by the model itself is always propagated.
        """
        supported_types = provider.get_model_type_list()
        if not any(entry.get('value') == model_type for entry in supported_types):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        # All connection fields must be present before attempting a live call.
        for field in ('volcanic_api_url', 'volcanic_app_id', 'volcanic_token'):
            if field in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=field))

        try:
            candidate = provider.get_model(model_type, model_name, model_credential, **model_params)
            candidate.check_auth()
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            if isinstance(e, AppApiException):
                raise e
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value,
                                  _('Verification failed, please check whether the parameters are correct: {error}').format(
                                      error=str(e)))
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of `model` with the access token masked for persistence."""
        masked_token = super().encryption(model.get('volcanic_token', ''))
        return {**model, 'volcanic_token': masked_token}

    def get_model_params_setting_form(self, model_name):
        """Expose the runtime parameter form for this model."""
        return VolcanicEngineBigModelSTTModelParams()
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# coding=utf-8
2+
3+
"""
4+
requires Python 3.6 or later
5+
6+
pip install asyncio
7+
pip install websockets
8+
"""
9+
10+
import base64
11+
import json
12+
import os
13+
import time
14+
import uuid
15+
import requests
16+
import uuid_utils.compat as uuid
17+
18+
from typing import Dict
19+
20+
from common.utils.logger import maxkb_logger
21+
from models_provider.base_model_provider import MaxKBBaseModel
22+
from models_provider.impl.base_stt import BaseSpeechToText
23+
24+
# NOTE(review): this constant appears unused in this module — confirm before relying on it.
audio_format = "mp3"  # wav or mp3, set according to the actual audio format
25+
26+
27+
def determine_api_mode(url):
    """Classify a Volcanic Engine ASR endpoint URL.

    Returns 'sync' for the flash-recognition endpoint, 'async_submit' /
    'async_query' for the asynchronous task endpoints, and 'unknown' for
    anything else. Markers are checked in order, mirroring the service's
    URL layout.
    """
    url_markers = (
        ('/recognize/flash', 'sync'),
        ('/submit', 'async_submit'),
        ('/query', 'async_query'),
    )
    for fragment, mode in url_markers:
        if fragment in url:
            return mode
    return 'unknown'
39+
40+
41+
class VolcanicASRClient:
    """HTTP client for the Volcanic Engine big-model ASR service.

    Supports the synchronous flash endpoint (/recognize/flash) and the
    asynchronous submit/query endpoint pair; the mode is derived from the
    request URL via determine_api_mode().
    """

    def __init__(self, appid, token):
        self.appid = appid
        self.token = token

    def _build_headers(self, url, task_id=None, x_tt_logid=None):
        """Build the authentication/request headers for the endpoint at `url`."""
        mode = determine_api_mode(url)

        headers = {
            "X-Api-App-Key": self.appid,
            "X-Api-Access-Key": self.token,
        }

        if mode == 'sync':
            headers.update({
                "X-Api-Resource-Id": "volc.bigasr.auc_turbo",
                "X-Api-Request-Id": str(uuid.uuid4()),
                "X-Api-Sequence": "-1",
            })
        elif mode == 'async_submit':
            headers.update({
                "X-Api-Resource-Id": "volc.bigasr.auc",
                "X-Api-Request-Id": task_id or str(uuid.uuid4()),
                "X-Api-Sequence": "-1",
            })
        elif mode == 'async_query':
            headers.update({
                "X-Api-Resource-Id": "volc.bigasr.auc",
                "X-Api-Request-Id": task_id or str(uuid.uuid4()),
                "X-Tt-Logid": x_tt_logid or "",
            })

        return headers

    def _create_request_body(self, audio_data, mode='sync'):
        """Build the JSON request body for the sync or async endpoint."""
        base_request = {
            # Sync requests carry the app id as uid; async submissions use a placeholder.
            "user": {"uid": self.appid if mode == 'sync' else "fake_uid"},
            "audio": audio_data,
        }

        if mode == 'sync':
            base_request["request"] = {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
            }
        else:  # async
            base_request["request"] = {
                "model_name": "bigmodel",
                "enable_channel_split": True,
                "enable_ddc": True,
                "enable_speaker_info": True,
                "enable_punc": True,
                "enable_itn": True,
                "corpus": {
                    "correct_table_name": "",
                    "context": ""
                }
            }

        return base_request

    def process_audio(self, audio_file=None, submit_url=None):
        """Recognize `audio_file`, choosing sync or async flow from `submit_url`.

        Returns the final requests.Response on success, or None on failure.
        Raises ValueError for URLs matching neither known endpoint.
        """
        # Reuse the shared helper instead of duplicating the base64 encoding inline.
        audio_data = self._get_audio_data(audio_file)

        mode = determine_api_mode(submit_url)
        if mode == 'sync':
            return self._sync_recognize(audio_data, submit_url)
        if mode == 'async_submit':
            return self._async_process(audio_data, submit_url)
        raise ValueError(f"Unsupported URL pattern: {submit_url}")

    def _get_audio_data(self, audio_file):
        """Read `audio_file` and wrap its content as a base64 payload."""
        base64_audio = base64.b64encode(audio_file.read()).decode("utf-8")
        return {"data": base64_audio}

    def _sync_recognize(self, audio_data, submit_url):
        """Synchronous (flash) recognition: a single POST returns the transcript."""
        headers = self._build_headers(submit_url)
        request_body = self._create_request_body(audio_data, mode='sync')

        response = requests.post(submit_url, json=request_body, headers=headers)
        return self._handle_response(response, "sync_recognize")

    def _async_process(self, audio_data, submit_url):
        """Asynchronous recognition: submit a task, then poll for its result."""
        task_id = str(uuid.uuid4())
        headers = self._build_headers(submit_url, task_id=task_id)
        request_body = self._create_request_body(audio_data, mode='async')

        submit_response = requests.post(submit_url, data=json.dumps(request_body), headers=headers)

        if submit_response.headers.get("X-Api-Status-Code") != "20000000":
            # Use the module logger instead of print() so failures reach the app log,
            # consistent with the rest of this module.
            maxkb_logger.error(f"Submit task failed: {submit_response.headers}")
            return None
        x_tt_logid = submit_response.headers.get("X-Tt-Logid", "")
        return self._poll_for_result(task_id, x_tt_logid)

    def _poll_for_result(self, task_id, x_tt_logid, max_attempts=None):
        """Poll the query endpoint once per second until the task finishes.

        `max_attempts=None` preserves the original wait-indefinitely behaviour;
        pass an int to bound the number of polls (returns None when exhausted).
        """
        query_url = "https://openspeech-direct.zijieapi.com/api/v3/auc/bigmodel/query"

        attempts = 0
        while max_attempts is None or attempts < max_attempts:
            attempts += 1
            query_response = self._query_task(task_id, x_tt_logid, query_url)
            if query_response is None:
                # _query_task/_handle_response already logged the failure;
                # previously this crashed on .headers of None.
                return None
            code = query_response.headers.get('X-Api-Status-Code', "")

            if code == '20000000':  # task finished
                return query_response
            if code not in ('20000001', '20000002'):  # neither queued nor running: failed
                maxkb_logger.error(f"Async task failed with code: {code}")
                return None
            time.sleep(1)
        maxkb_logger.error(f"Async task {task_id} still pending after {max_attempts} polls")
        return None

    def _query_task(self, task_id, x_tt_logid, query_url):
        """Execute a single status-query request for an async task."""
        headers = self._build_headers(query_url, task_id=task_id, x_tt_logid=x_tt_logid)
        response = requests.post(query_url, json.dumps({}), headers=headers)
        return self._handle_response(response, "async_query", silent=True)

    def _handle_response(self, response, operation, silent=False):
        """Return `response` when it carries an API status header, else None."""
        if 'X-Api-Status-Code' in response.headers:
            if not silent:
                # .get() avoids a KeyError when the optional headers are absent.
                maxkb_logger.info(f'{operation} response header X-Api-Status-Code: {response.headers["X-Api-Status-Code"]}')
                maxkb_logger.info(f'{operation} response header X-Api-Message: {response.headers.get("X-Api-Message", "")}')
                maxkb_logger.info(f'{operation} response header X-Tt-Logid: {response.headers.get("X-Tt-Logid", "")}')

                if operation == "sync_recognize":
                    maxkb_logger.info(f'sync response content: {response.json()}\n')

            return response
        else:
            maxkb_logger.error(f'{operation} failed: {response.headers}\n')
            return None
190+
191+
192+
class VolcanicEngineBigModelSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """MaxKB speech-to-text model backed by the Volcanic Engine big-model ASR API."""

    volcanic_app_id: str
    volcanic_api_url: str
    volcanic_token: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.params = kwargs.get('params')

    @staticmethod
    def is_cache_model():
        # A client is rebuilt per call; instances are not worth caching.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider framework to build a model from stored credentials.

        All runtime kwargs are forwarded via `params`. (A previously built
        `optional_params` dict of max_tokens/temperature was never used and
        has been removed as dead code.)
        """
        return VolcanicEngineBigModelSpeechToText(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            params=model_kwargs,
        )

    def check_auth(self):
        """Validate credentials by transcribing a bundled sample clip.

        Relies on speech_to_text raising on failure; a silent failure here
        would make credential verification always succeed.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as audio_file:
            self.speech_to_text(audio_file)

    def speech_to_text(self, audio_file):
        """Transcribe `audio_file` via the configured Volcanic Engine endpoint.

        Returns the recognized text (or None when the response carries no
        'result' payload). Raises on request failure so that check_auth()
        and is_valid() can actually detect bad credentials.
        """
        try:
            client = VolcanicASRClient(self.volcanic_app_id, self.volcanic_token)
            result = client.process_audio(audio_file, self.volcanic_api_url)
            # process_audio returns None when the request is rejected; the old
            # code then crashed on .status_code and swallowed the error, which
            # made credential verification pass even with invalid credentials.
            if result is None or result.status_code != 200:
                status = getattr(result, 'status_code', 'no response')
                raise Exception(f'Volcanic Engine ASR request failed: {status}')
            payload = result.json().get('result') or {}
            return payload.get('text')
        except Exception as e:
            # Log and re-raise instead of silently returning None.
            maxkb_logger.error(f'Error getting speech to text: {e}')
            raise

apps/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@
1212
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
1313
ModelInfoManage
1414
from models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
15+
from models_provider.impl.volcanic_engine_model_provider.credential.bigModel_stt import \
16+
VolcanicEngineBigModelSTTModelCredential
1517
from models_provider.impl.volcanic_engine_model_provider.credential.embedding import VolcanicEmbeddingCredential
1618
from models_provider.impl.volcanic_engine_model_provider.credential.image import \
1719
VolcanicEngineImageModelCredential
1820
from models_provider.impl.volcanic_engine_model_provider.credential.tti import VolcanicEngineTTIModelCredential
1921
from models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential
2022
from models_provider.impl.volcanic_engine_model_provider.credential.ttv import VolcanicEngineTTVModelCredential
23+
from models_provider.impl.volcanic_engine_model_provider.model.bigModel_stt import VolcanicEngineBigModelSpeechToText
2124
from models_provider.impl.volcanic_engine_model_provider.model.embedding import VolcanicEngineEmbeddingModel
2225
from models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage
2326
from models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
@@ -33,6 +36,7 @@
3336

3437
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
3538
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
39+
volcanic_engine_big_stt_model_credential = VolcanicEngineBigModelSTTModelCredential()
3640
volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential()
3741
volcanic_engine_image_model_credential = VolcanicEngineImageModelCredential()
3842
volcanic_engine_tti_model_credential = VolcanicEngineTTIModelCredential()
@@ -53,6 +57,11 @@
5357
ModelTypeConst.STT,
5458
volcanic_engine_stt_model_credential, VolcanicEngineSpeechToText
5559
),
60+
ModelInfo('bigmodel',
61+
'',
62+
ModelTypeConst.STT,
63+
volcanic_engine_big_stt_model_credential, VolcanicEngineBigModelSpeechToText
64+
),
5665
ModelInfo('tts',
5766
'',
5867
ModelTypeConst.TTS,

0 commit comments

Comments
 (0)