|
| 1 | +import base64 |
| 2 | +import json |
| 3 | +import uuid |
| 4 | +from pathlib import Path |
| 5 | +from urllib.parse import urlparse |
| 6 | + |
| 7 | +import httpx |
| 8 | + |
| 9 | +from astrbot import logger |
| 10 | +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path |
| 11 | +from astrbot.core.utils.io import download_file |
| 12 | +from astrbot.core.utils.media_utils import convert_audio_to_wav |
| 13 | +from astrbot.core.utils.tencent_record_helper import ( |
| 14 | + convert_to_pcm_wav, |
| 15 | + tencent_silk_to_wav, |
| 16 | +) |
| 17 | + |
| 18 | +from ..entities import ProviderType |
| 19 | +from ..provider import STTProvider |
| 20 | +from ..register import register_provider_adapter |
| 21 | + |
| 22 | +DEFAULT_STEPFUN_ASR_API_BASE = "https://api.stepfun.com/step_plan/v1" |
| 23 | +DEFAULT_STEPFUN_ASR_MODEL = "stepaudio-2.5-asr" |
| 24 | + |
| 25 | +SUPPORTED_AUDIO_FORMATS = {"mp3", "wav", "ogg", "pcm"} |
| 26 | + |
| 27 | + |
| 28 | +class StepFunASRError(Exception): |
| 29 | + pass |
| 30 | + |
| 31 | + |
| 32 | +def normalize_timeout(timeout: int | str | None) -> int | None: |
| 33 | + if timeout in (None, ""): |
| 34 | + return None |
| 35 | + if isinstance(timeout, str): |
| 36 | + return int(timeout) |
| 37 | + return timeout |
| 38 | + |
| 39 | + |
| 40 | +def normalize_bool(value: bool | str | int | None, default: bool) -> bool: |
| 41 | + if value is None or value == "": |
| 42 | + return default |
| 43 | + if isinstance(value, bool): |
| 44 | + return value |
| 45 | + if isinstance(value, int): |
| 46 | + return value != 0 |
| 47 | + return str(value).strip().lower() in {"1", "true", "yes", "on", "enabled"} |
| 48 | + |
| 49 | + |
| 50 | +def build_api_url(api_base: str) -> str: |
| 51 | + normalized_api_base = (api_base or DEFAULT_STEPFUN_ASR_API_BASE).rstrip("/") |
| 52 | + if normalized_api_base.endswith("/audio/asr/sse"): |
| 53 | + return normalized_api_base |
| 54 | + return normalized_api_base + "/audio/asr/sse" |
| 55 | + |
| 56 | + |
| 57 | +def build_headers(api_key: str) -> dict[str, str]: |
| 58 | + headers = { |
| 59 | + "Content-Type": "application/json", |
| 60 | + "Accept": "text/event-stream", |
| 61 | + } |
| 62 | + if api_key: |
| 63 | + headers["Authorization"] = f"Bearer {api_key}" |
| 64 | + return headers |
| 65 | + |
| 66 | + |
| 67 | +def create_http_client(timeout: int | None, proxy: str) -> httpx.AsyncClient: |
| 68 | + client_kwargs: dict[str, object] = { |
| 69 | + "timeout": timeout, |
| 70 | + "follow_redirects": True, |
| 71 | + } |
| 72 | + if proxy: |
| 73 | + logger.info("[StepFun ASR] Using proxy: %s", proxy) |
| 74 | + client_kwargs["proxy"] = proxy |
| 75 | + return httpx.AsyncClient(**client_kwargs) |
| 76 | + |
| 77 | + |
| 78 | +def get_temp_dir() -> Path: |
| 79 | + temp_dir = Path(get_astrbot_temp_path()) |
| 80 | + temp_dir.mkdir(parents=True, exist_ok=True) |
| 81 | + return temp_dir |
| 82 | + |
| 83 | + |
| 84 | +async def detect_audio_format(file_path: Path) -> str | None: |
| 85 | + try: |
| 86 | + with file_path.open("rb") as file: |
| 87 | + header = file.read(64) |
| 88 | + except FileNotFoundError: |
| 89 | + return None |
| 90 | + |
| 91 | + if header[:4] == b"RIFF" and header[8:12] == b"WAVE": |
| 92 | + return "wav" |
| 93 | + if header.startswith(b"#!AMR"): |
| 94 | + return "amr" |
| 95 | + if ( |
| 96 | + header.startswith(b"#!SILK_V3") |
| 97 | + or header.startswith(b"\x02#!SILK_V3") |
| 98 | + or b"SILK" in header[:16] |
| 99 | + ): |
| 100 | + return "silk" |
| 101 | + if header.startswith(b"OggS"): |
| 102 | + return "ogg" |
| 103 | + if header[:3] == b"ID3" or header[:2] == b"\xff\xfb": |
| 104 | + return "mp3" |
| 105 | + return None |
| 106 | + |
| 107 | + |
| 108 | +def build_audio_format(audio_type: str) -> dict[str, object]: |
| 109 | + if audio_type == "pcm": |
| 110 | + return { |
| 111 | + "type": "pcm", |
| 112 | + "codec": "pcm_s16le", |
| 113 | + "rate": 16000, |
| 114 | + "bits": 16, |
| 115 | + "channel": 1, |
| 116 | + } |
| 117 | + return {"type": audio_type} |
| 118 | + |
| 119 | + |
| 120 | +async def prepare_audio_input( |
| 121 | + audio_source: str, |
| 122 | +) -> tuple[str, dict[str, object], list[Path]]: |
| 123 | + cleanup_paths: list[Path] = [] |
| 124 | + source_path = Path(audio_source) |
| 125 | + is_remote = audio_source.startswith(("http://", "https://")) |
| 126 | + is_tencent = "multimedia.nt.qq.com.cn" in audio_source if is_remote else False |
| 127 | + |
| 128 | + if is_remote: |
| 129 | + parsed_url = urlparse(audio_source) |
| 130 | + suffix = Path(parsed_url.path).suffix or ".input" |
| 131 | + download_path = get_temp_dir() / f"stepfun_asr_{uuid.uuid4().hex[:8]}{suffix}" |
| 132 | + await download_file(audio_source, str(download_path)) |
| 133 | + source_path = download_path |
| 134 | + cleanup_paths.append(download_path) |
| 135 | + |
| 136 | + if not source_path.exists(): |
| 137 | + raise FileNotFoundError(f"File does not exist: {source_path}") |
| 138 | + |
| 139 | + audio_type = await detect_audio_format(source_path) |
| 140 | + if audio_type is None: |
| 141 | + audio_type = source_path.suffix.lower().lstrip(".") |
| 142 | + |
| 143 | + if source_path.suffix.lower() in {".amr", ".silk"} or is_tencent: |
| 144 | + file_format = await detect_audio_format(source_path) |
| 145 | + if file_format in {"silk", "amr"}: |
| 146 | + converted_path = get_temp_dir() / f"stepfun_asr_{uuid.uuid4().hex[:8]}.wav" |
| 147 | + cleanup_paths.append(converted_path) |
| 148 | + if file_format == "silk": |
| 149 | + logger.info("Converting silk file to wav for StepFun ASR...") |
| 150 | + await tencent_silk_to_wav(str(source_path), str(converted_path)) |
| 151 | + else: |
| 152 | + logger.info("Converting amr file to wav for StepFun ASR...") |
| 153 | + await convert_to_pcm_wav(str(source_path), str(converted_path)) |
| 154 | + source_path = converted_path |
| 155 | + audio_type = "wav" |
| 156 | + |
| 157 | + if audio_type not in SUPPORTED_AUDIO_FORMATS: |
| 158 | + converted_path = get_temp_dir() / f"stepfun_asr_{uuid.uuid4().hex[:8]}.wav" |
| 159 | + cleanup_paths.append(converted_path) |
| 160 | + logger.info("Converting audio file to wav for StepFun ASR...") |
| 161 | + await convert_audio_to_wav(str(source_path), str(converted_path)) |
| 162 | + source_path = converted_path |
| 163 | + audio_type = "wav" |
| 164 | + |
| 165 | + encoded_audio = base64.b64encode(source_path.read_bytes()).decode("utf-8") |
| 166 | + return encoded_audio, build_audio_format(audio_type), cleanup_paths |
| 167 | + |
| 168 | + |
| 169 | +def cleanup_files(paths: list[Path]) -> None: |
| 170 | + for path in paths: |
| 171 | + try: |
| 172 | + path.unlink(missing_ok=True) |
| 173 | + except Exception as exc: |
| 174 | + logger.warning( |
| 175 | + "Failed to remove temporary StepFun ASR file %s: %s", |
| 176 | + path, |
| 177 | + exc, |
| 178 | + ) |
| 179 | + |
| 180 | + |
| 181 | +def _iter_sse_payloads(content: str): |
| 182 | + for event in content.split("\n\n"): |
| 183 | + data_lines = [] |
| 184 | + for line in event.splitlines(): |
| 185 | + if line.startswith("data:"): |
| 186 | + data_lines.append(line[5:].strip()) |
| 187 | + if not data_lines and event.strip().startswith("{"): |
| 188 | + data_lines.append(event.strip()) |
| 189 | + for data in data_lines: |
| 190 | + if data and data != "[DONE]": |
| 191 | + yield data |
| 192 | + |
| 193 | + |
| 194 | +def _text_candidate(data: dict) -> str: |
| 195 | + for key in ("text", "transcript", "content", "delta"): |
| 196 | + value = data.get(key) |
| 197 | + if isinstance(value, str) and value: |
| 198 | + return value |
| 199 | + |
| 200 | + nested_data = data.get("data") |
| 201 | + if isinstance(nested_data, dict): |
| 202 | + for key in ("text", "transcript", "content", "delta"): |
| 203 | + value = nested_data.get(key) |
| 204 | + if isinstance(value, str) and value: |
| 205 | + return value |
| 206 | + |
| 207 | + choices = data.get("choices") |
| 208 | + if isinstance(choices, list) and choices: |
| 209 | + first_choice = choices[0] |
| 210 | + if isinstance(first_choice, dict): |
| 211 | + delta = first_choice.get("delta") |
| 212 | + if isinstance(delta, dict): |
| 213 | + content = delta.get("content") |
| 214 | + if isinstance(content, str) and content: |
| 215 | + return content |
| 216 | + message = first_choice.get("message") |
| 217 | + if isinstance(message, dict): |
| 218 | + content = message.get("content") |
| 219 | + if isinstance(content, str) and content: |
| 220 | + return content |
| 221 | + return "" |
| 222 | + |
| 223 | + |
| 224 | +def parse_sse_transcription(content: str) -> str: |
| 225 | + done_text = "" |
| 226 | + delta_parts: list[str] = [] |
| 227 | + |
| 228 | + for payload in _iter_sse_payloads(content): |
| 229 | + try: |
| 230 | + data = json.loads(payload) |
| 231 | + except json.JSONDecodeError: |
| 232 | + continue |
| 233 | + |
| 234 | + event_type = str(data.get("type") or data.get("event") or "") |
| 235 | + error = data.get("error") |
| 236 | + if event_type == "error" or error: |
| 237 | + message = data.get("message") or error or data |
| 238 | + raise StepFunASRError(f"StepFun ASR returned error: {message}") |
| 239 | + |
| 240 | + text = _text_candidate(data) |
| 241 | + if not text: |
| 242 | + continue |
| 243 | + if event_type.endswith(".done"): |
| 244 | + done_text = text |
| 245 | + else: |
| 246 | + delta_parts.append(text) |
| 247 | + |
| 248 | + result = done_text or "".join(delta_parts) |
| 249 | + if not result.strip(): |
| 250 | + raise StepFunASRError("StepFun ASR returned empty transcription") |
| 251 | + return result.strip() |
| 252 | + |
| 253 | + |
| 254 | +@register_provider_adapter( |
| 255 | + "stepfun_asr", |
| 256 | + "StepFun StepAudio ASR API", |
| 257 | + provider_type=ProviderType.SPEECH_TO_TEXT, |
| 258 | +) |
| 259 | +class ProviderStepFunASR(STTProvider): |
| 260 | + def __init__( |
| 261 | + self, |
| 262 | + provider_config: dict, |
| 263 | + provider_settings: dict, |
| 264 | + ) -> None: |
| 265 | + super().__init__(provider_config, provider_settings) |
| 266 | + self.chosen_api_key = provider_config.get("api_key", "") |
| 267 | + self.api_base = provider_config.get( |
| 268 | + "api_base", |
| 269 | + DEFAULT_STEPFUN_ASR_API_BASE, |
| 270 | + ) |
| 271 | + self.proxy = provider_config.get("proxy", "") |
| 272 | + self.timeout = normalize_timeout(provider_config.get("timeout", 20)) |
| 273 | + self.language = provider_config.get("stepfun-asr-language", "zh") |
| 274 | + self.enable_itn = normalize_bool( |
| 275 | + provider_config.get("stepfun-asr-enable-itn", True), |
| 276 | + True, |
| 277 | + ) |
| 278 | + self.set_model(provider_config.get("model", DEFAULT_STEPFUN_ASR_MODEL)) |
| 279 | + self.client = create_http_client(self.timeout, self.proxy) |
| 280 | + |
| 281 | + async def prepare_audio_input( |
| 282 | + self, |
| 283 | + audio_source: str, |
| 284 | + ) -> tuple[str, dict[str, object], list[Path]]: |
| 285 | + return await prepare_audio_input(audio_source) |
| 286 | + |
| 287 | + def _build_transcription_options(self) -> dict[str, object]: |
| 288 | + transcription: dict[str, object] = { |
| 289 | + "model": self.model_name, |
| 290 | + "language": self.language, |
| 291 | + "enable_itn": self.enable_itn, |
| 292 | + } |
| 293 | + return transcription |
| 294 | + |
| 295 | + async def get_text(self, audio_url: str) -> str: |
| 296 | + encoded_audio, audio_format, cleanup_paths = await self.prepare_audio_input( |
| 297 | + audio_url |
| 298 | + ) |
| 299 | + payload = { |
| 300 | + "audio": { |
| 301 | + "data": encoded_audio, |
| 302 | + "input": { |
| 303 | + "transcription": self._build_transcription_options(), |
| 304 | + "format": audio_format, |
| 305 | + }, |
| 306 | + }, |
| 307 | + } |
| 308 | + |
| 309 | + try: |
| 310 | + response = await self.client.post( |
| 311 | + build_api_url(self.api_base), |
| 312 | + headers=build_headers(self.chosen_api_key), |
| 313 | + json=payload, |
| 314 | + ) |
| 315 | + try: |
| 316 | + response.raise_for_status() |
| 317 | + except Exception as exc: |
| 318 | + error_text = response.text[:1024] |
| 319 | + raise StepFunASRError( |
| 320 | + "StepFun ASR API request failed: " |
| 321 | + f"HTTP {response.status_code}, response: {error_text}" |
| 322 | + ) from exc |
| 323 | + |
| 324 | + return parse_sse_transcription(response.text) |
| 325 | + finally: |
| 326 | + cleanup_files(cleanup_paths) |
| 327 | + |
| 328 | + async def terminate(self): |
| 329 | + if self.client: |
| 330 | + await self.client.aclose() |
0 commit comments