Skip to content

Commit 9b75c94

Browse files
feat: add speech-to-text for audio (WAV) via partition_audio and optional Whisper STT agent
1 parent d0f8620 commit 9b75c94

11 files changed

Lines changed: 319 additions & 5 deletions

File tree

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 0.21.7
2+
3+
### Enhancements
4+
- **Add speech-to-text to multimodal pipeline**: Audio files (WAV) can now be partitioned into document elements via speech-to-text. Install the optional `audio` extra (`pip install "unstructured[audio]"`) to use the Whisper-based partitioner. Call `partition()` or `partition_audio()` with a WAV file to get a transcript as `NarrativeText` elements. The `STT_AGENT` environment variable selects the speech-to-text implementation (default: Whisper).
5+
16
## 0.21.6
27

38
### Enhancements

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,12 @@ xlsx = [
111111
"pandas>=2.0.0, <4.0.0",
112112
"xlrd>=2.0.1, <3.0.0",
113113
]
114+
# Speech-to-text for partition_audio (multimodal: audio -> elements)
115+
audio = [
116+
"openai-whisper>=20231117, <20260000",
117+
]
114118
all-docs = [
115-
"unstructured[csv,doc,docx,epub,image,md,odt,org,pdf,ppt,pptx,rtf,rst,tsv,xlsx]",
119+
"unstructured[audio,csv,doc,docx,epub,image,md,odt,org,pdf,ppt,pptx,rtf,rst,tsv,xlsx]",
116120
]
117121
# Feature extras
118122
chunking-tokens = [
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# pyright: reportPrivateUsage=false
2+
3+
"""Tests for partition_audio (speech-to-text in multimodal pipeline)."""
4+
5+
from __future__ import annotations
6+
7+
import tempfile
8+
from pathlib import Path
9+
from unittest.mock import patch
10+
11+
import pytest
12+
13+
from unstructured.documents.elements import NarrativeText
14+
from unstructured.file_utils.model import FileType
15+
from unstructured.partition.audio import partition_audio
16+
17+
18+
def test_partition_audio_raises_with_neither_filename_nor_file():
19+
with pytest.raises(ValueError, match="Exactly one of .* must be specified"):
20+
partition_audio()
21+
22+
23+
def test_partition_audio_raises_with_both_filename_and_file():
24+
with pytest.raises(ValueError, match="Exactly one of .* must be specified"):
25+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
26+
partition_audio(filename=tmp.name, file=tmp)
27+
28+
29+
@patch(
30+
"unstructured.partition.audio.SpeechToTextAgent.get_instance",
31+
)
32+
def test_partition_audio_from_filename_returns_transcript_elements(mock_get_instance):
33+
mock_agent = mock_get_instance.return_value
34+
mock_agent.transcribe.return_value = "Hello, this is a test transcript."
35+
36+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
37+
path = tmp.name
38+
tmp.write(b"\x00" * 44) # minimal WAV-like header
39+
tmp.flush()
40+
41+
try:
42+
elements = partition_audio(filename=path)
43+
finally:
44+
Path(path).unlink(missing_ok=True)
45+
46+
assert len(elements) == 1
47+
assert isinstance(elements[0], NarrativeText)
48+
assert elements[0].text == "Hello, this is a test transcript."
49+
assert elements[0].metadata.detection_origin == "speech_to_text"
50+
mock_agent.transcribe.assert_called_once_with(path, language=None)
51+
52+
53+
@patch(
54+
"unstructured.partition.audio.SpeechToTextAgent.get_instance",
55+
)
56+
def test_partition_audio_from_file_uses_temp_path_and_cleans_up(mock_get_instance):
57+
mock_agent = mock_get_instance.return_value
58+
mock_agent.transcribe.return_value = "From file object."
59+
60+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
61+
tmp.write(b"\x00" * 44)
62+
tmp.flush()
63+
tmp.seek(0)
64+
elements = partition_audio(file=tmp, metadata_filename="recording.wav")
65+
66+
assert len(elements) == 1
67+
assert elements[0].text == "From file object."
68+
assert elements[0].metadata.filename == "recording.wav"
69+
70+
71+
@patch(
72+
"unstructured.partition.audio.SpeechToTextAgent.get_instance",
73+
)
74+
def test_partition_audio_empty_transcript_returns_empty_list(mock_get_instance):
75+
mock_agent = mock_get_instance.return_value
76+
mock_agent.transcribe.return_value = " "
77+
78+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
79+
path = tmp.name
80+
tmp.write(b"\x00" * 44)
81+
tmp.flush()
82+
83+
try:
84+
elements = partition_audio(filename=path)
85+
finally:
86+
Path(path).unlink(missing_ok=True)
87+
88+
assert elements == []
89+
90+
91+
def test_wav_file_type_is_partitionable():
92+
assert FileType.WAV.is_partitionable
93+
assert FileType.WAV.partitioner_shortname == "audio"
94+
assert FileType.WAV.partitioner_function_name == "partition_audio"

unstructured/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.21.6" # pragma: no cover
1+
__version__ = "0.21.7" # pragma: no cover

unstructured/file_utils/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,9 @@ def partitioner_shortname(self) -> str | None:
441441
)
442442
WAV = (
443443
"wav",
444-
None,
445-
cast(list[str], []),
446-
None,
444+
"audio",
445+
["whisper"],
446+
"audio",
447447
[".wav"],
448448
"audio/wav",
449449
[

unstructured/partition/audio.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Partition audio files into elements using speech-to-text transcription."""
2+
3+
from __future__ import annotations
4+
5+
import tempfile
6+
from pathlib import Path
7+
from typing import IO, Any
8+
9+
from unstructured.chunking import add_chunking_strategy
10+
from unstructured.documents.elements import Element, NarrativeText
11+
from unstructured.file_utils.model import FileType
12+
from unstructured.partition.common.common import exactly_one
13+
from unstructured.partition.common.metadata import apply_metadata, get_last_modified_date
14+
from unstructured.partition.utils.config import env_config
15+
from unstructured.partition.utils.speech_to_text.speech_to_text_interface import (
16+
SpeechToTextAgent,
17+
)
18+
19+
20+
@apply_metadata(FileType.WAV)
21+
@add_chunking_strategy
22+
def partition_audio(
23+
filename: str | None = None,
24+
*,
25+
file: IO[bytes] | None = None,
26+
language: str | None = None,
27+
stt_agent: str | None = None,
28+
metadata_filename: str | None = None,
29+
metadata_last_modified: str | None = None,
30+
**kwargs: Any,
31+
) -> list[Element]:
32+
"""Partition an audio file (e.g. WAV) into elements using speech-to-text.
33+
34+
Transcribes the audio and returns a single NarrativeText element containing
35+
the full transcript. Requires the optional `audio` extra with Whisper:
36+
``pip install "unstructured[audio]"``.
37+
38+
Parameters
39+
----------
40+
filename
41+
Path to the audio file.
42+
file
43+
File-like object opened in binary mode (e.g. ``open("audio.wav", "rb")``).
44+
language
45+
Optional ISO 639-1 language code for the spoken language (e.g. "en").
46+
When None, the speech-to-text agent may auto-detect.
47+
stt_agent
48+
Optional fully-qualified class name of the SpeechToTextAgent implementation.
49+
Defaults to the Whisper agent when the audio extra is installed.
50+
metadata_filename
51+
Filename to store in element metadata when partitioning from a file object.
52+
metadata_last_modified
53+
Last modified date to store in element metadata.
54+
"""
55+
exactly_one(filename=filename, file=file)
56+
57+
audio_path: str
58+
if filename is not None:
59+
audio_path = filename
60+
else:
61+
if file is None:
62+
raise ValueError("Either filename or file must be provided.")
63+
file.seek(0)
64+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
65+
tmp.write(file.read())
66+
audio_path = tmp.name
67+
68+
try:
69+
agent_module = stt_agent or env_config.STT_AGENT
70+
agent = SpeechToTextAgent.get_instance(agent_module)
71+
text = agent.transcribe(audio_path, language=language)
72+
finally:
73+
if filename is None and audio_path.startswith(tempfile.gettempdir()):
74+
Path(audio_path).unlink(missing_ok=True)
75+
76+
if not text.strip():
77+
return []
78+
79+
metadata_kwargs: dict[str, Any] = {}
80+
if metadata_filename:
81+
metadata_kwargs["filename"] = metadata_filename
82+
elif filename:
83+
metadata_kwargs["filename"] = filename
84+
if metadata_last_modified:
85+
metadata_kwargs["last_modified"] = metadata_last_modified
86+
elif filename:
87+
last_modified = get_last_modified_date(filename)
88+
if last_modified:
89+
metadata_kwargs["last_modified"] = last_modified
90+
91+
element = NarrativeText(text=text)
92+
element.metadata.detection_origin = "speech_to_text"
93+
element.metadata.update(element.metadata.__class__(**metadata_kwargs))
94+
95+
return [element]

unstructured/partition/utils/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ def OCR_AGENT_CACHE_SIZE(self) -> int:
116116
"""Maximum number of OCR agents to cache per process"""
117117
return self._get_int("OCR_AGENT_CACHE_SIZE", 1)
118118

119+
@property
120+
def STT_AGENT(self) -> str:
121+
"""Speech-to-text agent module for partition_audio (e.g. Whisper)."""
122+
from unstructured.partition.utils.constants import STT_AGENT_WHISPER
123+
124+
return self._get_string("STT_AGENT", STT_AGENT_WHISPER)
125+
119126
@property
120127
def EXTRACT_IMAGE_BLOCK_CROP_HORIZONTAL_PAD(self) -> int:
121128
"""extra image block content to add around an identified element(`Image`, `Table`) region

unstructured/partition/utils/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ class PartitionStrategy:
4141
"unstructured.partition.utils.ocr_models.google_vision_ocr",
4242
).split(",")
4343

44+
# Speech-to-text agent (used by partition_audio)
45+
STT_AGENT_WHISPER = "unstructured.partition.utils.speech_to_text.whisper_stt.SpeechToTextAgentWhisper"
46+
STT_AGENT_MODULES_WHITELIST = (
47+
os.getenv(
48+
"STT_AGENT_MODULES_WHITELIST",
49+
"unstructured.partition.utils.speech_to_text.whisper_stt",
50+
).split(",")
51+
)
52+
4453
UNSTRUCTURED_INCLUDE_DEBUG_METADATA = os.getenv("UNSTRUCTURED_INCLUDE_DEBUG_METADATA", False)
4554

4655
# this field is defined by unstructured_pytesseract
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Speech-to-text agents for transcribing audio in the multimodal partition pipeline."""
2+
3+
from unstructured.partition.utils.speech_to_text.speech_to_text_interface import (
4+
SpeechToTextAgent,
5+
)
6+
7+
__all__ = ["SpeechToTextAgent"]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Abstract interface for speech-to-text (STT) agents used by the audio partitioner."""
2+
3+
from __future__ import annotations
4+
5+
import functools
6+
import importlib
7+
from abc import ABC, abstractmethod
8+
from typing import TYPE_CHECKING
9+
10+
from unstructured.logger import logger
11+
from unstructured.partition.utils.constants import STT_AGENT_MODULES_WHITELIST
12+
13+
if TYPE_CHECKING:
14+
pass
15+
16+
17+
class SpeechToTextAgent(ABC):
18+
"""Defines the interface for a speech-to-text transcription service."""
19+
20+
@staticmethod
21+
@functools.lru_cache(maxsize=1)
22+
def get_instance(agent_module: str) -> "SpeechToTextAgent":
23+
"""Load and return the configured SpeechToTextAgent implementation.
24+
25+
The implementation is determined by the `STT_AGENT` environment variable
26+
or the passed `agent_module` (e.g. whisper implementation).
27+
"""
28+
module_name, class_name = agent_module.rsplit(".", 1)
29+
if module_name not in STT_AGENT_MODULES_WHITELIST:
30+
raise ValueError(
31+
f"Speech-to-text agent module {module_name} must be in the whitelist: "
32+
f"{STT_AGENT_MODULES_WHITELIST}."
33+
)
34+
try:
35+
mod = importlib.import_module(module_name)
36+
cls = getattr(mod, class_name)
37+
return cls()
38+
except (ImportError, AttributeError) as e:
39+
logger.error(f"Failed to get SpeechToTextAgent instance: {e}")
40+
raise RuntimeError(
41+
"Could not load the SpeechToText agent. Install the audio extra: "
42+
'pip install "unstructured[audio]"'
43+
) from e
44+
45+
@abstractmethod
46+
def transcribe(self, audio_path: str, *, language: str | None = None) -> str:
47+
"""Transcribe audio from a file path to text.
48+
49+
Parameters
50+
----------
51+
audio_path
52+
Path to an audio file (e.g. WAV, MP3).
53+
language
54+
Optional ISO 639-1 language code for the spoken language (e.g. "en").
55+
When None, the agent may auto-detect.
56+
57+
Returns
58+
-------
59+
Transcribed text.
60+
"""
61+
pass

0 commit comments

Comments
 (0)