Skip to content

Commit 523eefd

Browse files
Rhonin
authored and committed
feat(provider): add StepFun ASR provider
1 parent 2d78626 commit 523eefd

4 files changed

Lines changed: 359 additions & 0 deletions

File tree

astrbot/core/config/default.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1557,6 +1557,20 @@
15571557
"timeout": "20",
15581558
"proxy": "",
15591559
},
1560+
"StepFun ASR(API)": {
1561+
"id": "stepfun_asr",
1562+
"provider": "stepfun",
1563+
"type": "stepfun_asr",
1564+
"provider_type": "speech_to_text",
1565+
"enable": False,
1566+
"api_key": "",
1567+
"api_base": "https://api.stepfun.com/step_plan/v1",
1568+
"model": "stepaudio-2.5-asr",
1569+
"stepfun-asr-language": "zh",
1570+
"stepfun-asr-enable-itn": True,
1571+
"timeout": "20",
1572+
"proxy": "",
1573+
},
15601574
"Whisper(Local)": {
15611575
"provider": "openai",
15621576
"type": "openai_whisper_selfhost",
@@ -2547,6 +2561,16 @@
25472561
"type": "string",
25482562
"hint": "附加给 MiMo STT 的用户提示词,用于约束返回结果格式。",
25492563
},
2564+
"stepfun-asr-language": {
2565+
"description": "语言",
2566+
"type": "string",
2567+
"hint": "StepFun ASR 的识别语言。默认 zh。",
2568+
},
2569+
"stepfun-asr-enable-itn": {
2570+
"description": "数字规整",
2571+
"type": "bool",
2572+
"hint": "是否启用 StepFun ASR 的 ITN 数字规整。",
2573+
},
25502574
"openai-tts-voice": {
25512575
"description": "voice",
25522576
"type": "string",

astrbot/core/provider/manager.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,10 @@ def dynamic_import_provider(self, type: str) -> None:
405405
from .sources.mimo_stt_api_source import (
406406
ProviderMiMoSTTAPI as ProviderMiMoSTTAPI,
407407
)
408+
case "stepfun_asr":
409+
from .sources.stepfun_asr_source import (
410+
ProviderStepFunASR as ProviderStepFunASR,
411+
)
408412
case "openai_whisper_selfhost":
409413
from .sources.whisper_selfhosted_source import (
410414
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
astrbot/core/provider/sources/stepfun_asr_source.py

Lines changed: 330 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,330 @@
1+
import base64
2+
import json
3+
import uuid
4+
from pathlib import Path
5+
from urllib.parse import urlparse
6+
7+
import httpx
8+
9+
from astrbot import logger
10+
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
11+
from astrbot.core.utils.io import download_file
12+
from astrbot.core.utils.media_utils import convert_audio_to_wav
13+
from astrbot.core.utils.tencent_record_helper import (
14+
convert_to_pcm_wav,
15+
tencent_silk_to_wav,
16+
)
17+
18+
from ..entities import ProviderType
19+
from ..provider import STTProvider
20+
from ..register import register_provider_adapter
21+
22+
# Base URL used when the provider config supplies no (or an empty) "api_base".
DEFAULT_STEPFUN_ASR_API_BASE = "https://api.stepfun.com/step_plan/v1"
23+
# Model name used when the provider config has no "model" entry.
DEFAULT_STEPFUN_ASR_MODEL = "stepaudio-2.5-asr"
24+
25+
# Audio types that are uploaded as-is; anything else is converted to WAV first.
SUPPORTED_AUDIO_FORMATS = {"mp3", "wav", "ogg", "pcm"}
26+
27+
28+
class StepFunASRError(Exception):
    """Raised when a StepFun ASR request fails or yields no usable transcript."""
30+
31+
32+
def normalize_timeout(timeout: int | str | None) -> int | None:
33+
if timeout in (None, ""):
34+
return None
35+
if isinstance(timeout, str):
36+
return int(timeout)
37+
return timeout
38+
39+
40+
def normalize_bool(value: bool | str | int | None, default: bool) -> bool:
41+
if value is None or value == "":
42+
return default
43+
if isinstance(value, bool):
44+
return value
45+
if isinstance(value, int):
46+
return value != 0
47+
return str(value).strip().lower() in {"1", "true", "yes", "on", "enabled"}
48+
49+
50+
def build_api_url(api_base: str) -> str:
    """Return the full SSE transcription endpoint for ``api_base``.

    An empty ``api_base`` falls back to the module default. Trailing slashes
    are dropped, and the endpoint path is appended only when the base does not
    already end with it.
    """
    endpoint_suffix = "/audio/asr/sse"
    base = api_base if api_base else DEFAULT_STEPFUN_ASR_API_BASE
    base = base.rstrip("/")
    if not base.endswith(endpoint_suffix):
        base = base + endpoint_suffix
    return base
55+
56+
57+
def build_headers(api_key: str) -> dict[str, str]:
    """Build the request headers for the StepFun ASR endpoint.

    The request body is JSON and an event stream is requested in response. A
    bearer ``Authorization`` header is included only when ``api_key`` is
    non-empty.
    """
    auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    return {
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
        **auth_header,
    }
65+
66+
67+
def create_http_client(timeout: int | None, proxy: str) -> httpx.AsyncClient:
    """Create the async HTTP client used for StepFun ASR requests.

    The client follows redirects and uses ``timeout`` as given. When ``proxy``
    is non-empty it is logged and passed to the client.
    """
    if not proxy:
        return httpx.AsyncClient(timeout=timeout, follow_redirects=True)

    logger.info("[StepFun ASR] Using proxy: %s", proxy)
    return httpx.AsyncClient(timeout=timeout, follow_redirects=True, proxy=proxy)
76+
77+
78+
def get_temp_dir() -> Path:
    """Return the AstrBot temp directory as a ``Path``, creating it if absent."""
    directory = Path(get_astrbot_temp_path())
    # Create missing parents too; an already-existing directory is fine.
    directory.mkdir(parents=True, exist_ok=True)
    return directory
82+
83+
84+
async def detect_audio_format(file_path: Path) -> str | None:
85+
try:
86+
with file_path.open("rb") as file:
87+
header = file.read(64)
88+
except FileNotFoundError:
89+
return None
90+
91+
if header[:4] == b"RIFF" and header[8:12] == b"WAVE":
92+
return "wav"
93+
if header.startswith(b"#!AMR"):
94+
return "amr"
95+
if (
96+
header.startswith(b"#!SILK_V3")
97+
or header.startswith(b"\x02#!SILK_V3")
98+
or b"SILK" in header[:16]
99+
):
100+
return "silk"
101+
if header.startswith(b"OggS"):
102+
return "ogg"
103+
if header[:3] == b"ID3" or header[:2] == b"\xff\xfb":
104+
return "mp3"
105+
return None
106+
107+
108+
def build_audio_format(audio_type: str) -> dict[str, object]:
    """Build the ``format`` descriptor sent alongside the audio data.

    Every type yields ``{"type": audio_type}``. Raw PCM additionally declares
    its codec, sample rate, bit depth and channel count.
    """
    descriptor: dict[str, object] = {"type": audio_type}
    if audio_type == "pcm":
        descriptor.update(codec="pcm_s16le", rate=16000, bits=16, channel=1)
    return descriptor
118+
119+
120+
async def prepare_audio_input(
    audio_source: str,
) -> tuple[str, dict[str, object], list[Path]]:
    """Load an audio source and turn it into a StepFun ASR request payload.

    Remote ``http(s)`` sources are downloaded into the temp directory. SILK
    and AMR files (recognised by header, for ``.silk``/``.amr`` files or
    Tencent media URLs) and any type outside ``SUPPORTED_AUDIO_FORMATS`` are
    converted to WAV first.

    Args:
        audio_source: Local file path or ``http(s)`` URL of the audio.

    Returns:
        A tuple ``(base64_audio, format_descriptor, cleanup_paths)``. The
        caller is responsible for deleting ``cleanup_paths`` once the request
        has finished.

    Raises:
        FileNotFoundError: If the (downloaded or local) file does not exist.
    """
    cleanup_paths: list[Path] = []
    try:
        source_path = Path(audio_source)
        is_remote = audio_source.startswith(("http://", "https://"))
        is_tencent = is_remote and "multimedia.nt.qq.com.cn" in audio_source

        if is_remote:
            parsed_url = urlparse(audio_source)
            suffix = Path(parsed_url.path).suffix or ".input"
            download_path = (
                get_temp_dir() / f"stepfun_asr_{uuid.uuid4().hex[:8]}{suffix}"
            )
            # Register before downloading so a partially written file is
            # still removed if the download fails.
            cleanup_paths.append(download_path)
            await download_file(audio_source, str(download_path))
            source_path = download_path

        if not source_path.exists():
            raise FileNotFoundError(f"File does not exist: {source_path}")

        # Sniff the header once; fall back to the file extension.
        detected_format = await detect_audio_format(source_path)
        audio_type = detected_format
        if audio_type is None:
            audio_type = source_path.suffix.lower().lstrip(".")

        needs_voice_decode = (
            source_path.suffix.lower() in {".amr", ".silk"} or is_tencent
        )
        if needs_voice_decode and detected_format in {"silk", "amr"}:
            converted_path = (
                get_temp_dir() / f"stepfun_asr_{uuid.uuid4().hex[:8]}.wav"
            )
            cleanup_paths.append(converted_path)
            if detected_format == "silk":
                logger.info("Converting silk file to wav for StepFun ASR...")
                await tencent_silk_to_wav(str(source_path), str(converted_path))
            else:
                logger.info("Converting amr file to wav for StepFun ASR...")
                await convert_to_pcm_wav(str(source_path), str(converted_path))
            source_path = converted_path
            audio_type = "wav"

        if audio_type not in SUPPORTED_AUDIO_FORMATS:
            converted_path = (
                get_temp_dir() / f"stepfun_asr_{uuid.uuid4().hex[:8]}.wav"
            )
            cleanup_paths.append(converted_path)
            logger.info("Converting audio file to wav for StepFun ASR...")
            await convert_audio_to_wav(str(source_path), str(converted_path))
            source_path = converted_path
            audio_type = "wav"

        encoded_audio = base64.b64encode(source_path.read_bytes()).decode("utf-8")
        return encoded_audio, build_audio_format(audio_type), cleanup_paths
    except BaseException:
        # On failure the caller never receives cleanup_paths, so remove the
        # temp files here before propagating (covers cancellation too).
        cleanup_files(cleanup_paths)
        raise
167+
168+
169+
def cleanup_files(paths: list[Path]) -> None:
    """Delete temporary files on a best-effort basis.

    Missing files are ignored. Any other failure is logged as a warning and
    does not stop the remaining paths from being processed.
    """
    for temp_path in paths:
        try:
            temp_path.unlink(missing_ok=True)
        except Exception as error:
            logger.warning(
                "Failed to remove temporary StepFun ASR file %s: %s",
                temp_path,
                error,
            )
179+
180+
181+
def _iter_sse_payloads(content: str):
182+
for event in content.split("\n\n"):
183+
data_lines = []
184+
for line in event.splitlines():
185+
if line.startswith("data:"):
186+
data_lines.append(line[5:].strip())
187+
if not data_lines and event.strip().startswith("{"):
188+
data_lines.append(event.strip())
189+
for data in data_lines:
190+
if data and data != "[DONE]":
191+
yield data
192+
193+
194+
def _text_candidate(data: dict) -> str:
195+
for key in ("text", "transcript", "content", "delta"):
196+
value = data.get(key)
197+
if isinstance(value, str) and value:
198+
return value
199+
200+
nested_data = data.get("data")
201+
if isinstance(nested_data, dict):
202+
for key in ("text", "transcript", "content", "delta"):
203+
value = nested_data.get(key)
204+
if isinstance(value, str) and value:
205+
return value
206+
207+
choices = data.get("choices")
208+
if isinstance(choices, list) and choices:
209+
first_choice = choices[0]
210+
if isinstance(first_choice, dict):
211+
delta = first_choice.get("delta")
212+
if isinstance(delta, dict):
213+
content = delta.get("content")
214+
if isinstance(content, str) and content:
215+
return content
216+
message = first_choice.get("message")
217+
if isinstance(message, dict):
218+
content = message.get("content")
219+
if isinstance(content, str) and content:
220+
return content
221+
return ""
222+
223+
224+
def parse_sse_transcription(content: str) -> str:
    """Assemble the final transcript from an SSE response body.

    Text from an event whose type ends with ``.done`` takes precedence (the
    last such event wins). Otherwise the text of all other events is
    concatenated in order.

    Args:
        content: Raw response body returned by the ASR endpoint.

    Returns:
        The transcript with surrounding whitespace removed.

    Raises:
        StepFunASRError: If an event reports an error, or no non-blank text
            was produced.
    """
    done_text = ""
    delta_parts: list[str] = []

    for payload in _iter_sse_payloads(content):
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            continue
        # Valid JSON that is not an object (string, number, list) has no
        # event fields; skip it rather than failing on ``.get``.
        if not isinstance(data, dict):
            continue

        event_type = str(data.get("type") or data.get("event") or "")
        error = data.get("error")
        if event_type == "error" or error:
            message = data.get("message") or error or data
            raise StepFunASRError(f"StepFun ASR returned error: {message}")

        text = _text_candidate(data)
        if not text:
            continue
        if event_type.endswith(".done"):
            done_text = text
        else:
            delta_parts.append(text)

    result = done_text or "".join(delta_parts)
    if not result.strip():
        raise StepFunASRError("StepFun ASR returned empty transcription")
    return result.strip()
252+
253+
254+
@register_provider_adapter(
    "stepfun_asr",
    "StepFun StepAudio ASR API",
    provider_type=ProviderType.SPEECH_TO_TEXT,
)
class ProviderStepFunASR(STTProvider):
    """Speech-to-text provider backed by the StepFun ASR SSE endpoint.

    Audio is sent base64-encoded in a JSON body and the transcript is parsed
    from the event-stream response.
    """

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Read connection and transcription options from ``provider_config``."""
        super().__init__(provider_config, provider_settings)
        config = provider_config
        self.chosen_api_key = config.get("api_key", "")
        self.api_base = config.get("api_base", DEFAULT_STEPFUN_ASR_API_BASE)
        self.proxy = config.get("proxy", "")
        self.timeout = normalize_timeout(config.get("timeout", 20))
        self.language = config.get("stepfun-asr-language", "zh")
        self.enable_itn = normalize_bool(
            config.get("stepfun-asr-enable-itn", True),
            True,
        )
        self.set_model(config.get("model", DEFAULT_STEPFUN_ASR_MODEL))
        self.client = create_http_client(self.timeout, self.proxy)

    async def prepare_audio_input(
        self,
        audio_source: str,
    ) -> tuple[str, dict[str, object], list[Path]]:
        """Delegate to the module-level ``prepare_audio_input`` helper."""
        return await prepare_audio_input(audio_source)

    def _build_transcription_options(self) -> dict[str, object]:
        """Return the ``transcription`` section of the request body."""
        return {
            "model": self.model_name,
            "language": self.language,
            "enable_itn": self.enable_itn,
        }

    async def get_text(self, audio_url: str) -> str:
        """Transcribe the audio at ``audio_url`` and return the text.

        Raises:
            StepFunASRError: If the HTTP status check fails; the message
                includes the status code and up to 1024 characters of the
                response body.
        """
        audio_b64, format_spec, temp_files = await self.prepare_audio_input(
            audio_url
        )
        request_body = {
            "audio": {
                "data": audio_b64,
                "input": {
                    "transcription": self._build_transcription_options(),
                    "format": format_spec,
                },
            },
        }

        try:
            endpoint = build_api_url(self.api_base)
            request_headers = build_headers(self.chosen_api_key)
            response = await self.client.post(
                endpoint,
                headers=request_headers,
                json=request_body,
            )
            try:
                response.raise_for_status()
            except Exception as exc:
                body_excerpt = response.text[:1024]
                failure = (
                    "StepFun ASR API request failed: "
                    f"HTTP {response.status_code}, response: {body_excerpt}"
                )
                raise StepFunASRError(failure) from exc

            return parse_sse_transcription(response.text)
        finally:
            # Temp files are removed whether the request succeeded or not.
            cleanup_files(temp_files)

    async def terminate(self):
        """Close the HTTP client, if one is set."""
        if not self.client:
            return
        await self.client.aclose()

dashboard/src/utils/providerUtils.js

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ export function getProviderIcon(type) {
3535
'minimax': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/minimax.svg',
3636
'minimax-token-plan': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/minimax.svg',
3737
'mimo': 'https://platform.xiaomimimo.com/favicon.874c9507.png',
38+
'stepfun': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/stepfun-color.svg',
3839
'302ai': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@1.53.0/icons/ai302-color.svg',
3940
'microsoft': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/microsoft.svg',
4041
'vllm': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/vllm.svg',

0 commit comments

Comments
 (0)