Skip to content

Commit 282431b

Browse files
fix: update
1 parent 70edd11 commit 282431b

7 files changed

Lines changed: 261 additions & 41 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ xlsx = [
113113
]
114114
# Speech-to-text for partition_audio (multimodal: audio -> elements)
115115
audio = [
116-
"openai-whisper>=20231117, <20260000",
116+
"openai-whisper>=20231117, <20270000",
117117
]
118118
all-docs = [
119119
"unstructured[audio,csv,doc,docx,epub,image,md,odt,org,pdf,ppt,pptx,rtf,rst,tsv,xlsx]",

test_unstructured/partition/test_audio.py

Lines changed: 209 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,21 @@ def test_partition_audio_raises_with_neither_filename_nor_file():
2222

2323

2424
def test_partition_audio_raises_with_both_filename_and_file():
25-
with pytest.raises(ValueError, match="Exactly one of .* must be specified"):
26-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
27-
partition_audio(filename=tmp.name, file=tmp)
25+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
26+
path = tmp.name
27+
try:
28+
with pytest.raises(ValueError, match="Exactly one of .* must be specified"):
29+
with open(path, "rb") as f:
30+
partition_audio(filename=path, file=f)
31+
finally:
32+
Path(path).unlink(missing_ok=True)
2833

2934

3035
@patch(
31-
"unstructured.partition.audio.SpeechToTextAgent.get_instance",
36+
"unstructured.partition.audio.SpeechToTextAgent.get_agent",
3237
)
33-
def test_partition_audio_from_filename_returns_transcript_elements(mock_get_instance):
34-
mock_agent = mock_get_instance.return_value
38+
def test_partition_audio_from_filename_returns_transcript_elements(mock_get_agent):
39+
mock_agent = mock_get_agent.return_value
3540
mock_agent.transcribe_segments.return_value = [
3641
{"text": "Hello, this is a test transcript.", "start": 0.0, "end": 2.5},
3742
]
@@ -52,34 +57,73 @@ def test_partition_audio_from_filename_returns_transcript_elements(mock_get_inst
5257
assert elements[0].metadata.detection_origin == "speech_to_text"
5358
assert elements[0].metadata.segment_start_seconds == 0.0
5459
assert elements[0].metadata.segment_end_seconds == 2.5
60+
mock_get_agent.assert_called_once_with(None)
5561
mock_agent.transcribe_segments.assert_called_once_with(path, language=None)
5662

5763

5864
@patch(
59-
"unstructured.partition.audio.SpeechToTextAgent.get_instance",
65+
"unstructured.partition.audio.SpeechToTextAgent.get_agent",
6066
)
61-
def test_partition_audio_from_file_uses_temp_path_and_cleans_up(mock_get_instance):
62-
mock_agent = mock_get_instance.return_value
67+
def test_partition_audio_from_file_uses_temp_path_and_cleans_up(mock_get_agent):
68+
mock_agent = mock_get_agent.return_value
6369
mock_agent.transcribe_segments.return_value = [
6470
{"text": "From file object.", "start": 0.0, "end": 1.0},
6571
]
6672

73+
captured_temp_path: list[str] = []
74+
real_named_temp = tempfile.NamedTemporaryFile
75+
76+
def spy_named_temp(*args, **kwargs):
77+
ctx = real_named_temp(*args, **kwargs)
78+
captured_temp_path.append(ctx.name)
79+
return ctx
80+
6781
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
6882
tmp.write(b"\x00" * 44)
6983
tmp.flush()
7084
tmp.seek(0)
71-
elements = partition_audio(file=tmp, metadata_filename="recording.wav")
85+
with patch("unstructured.partition.audio.tempfile.NamedTemporaryFile", spy_named_temp):
86+
elements = partition_audio(file=tmp, metadata_filename="recording.wav")
7287

7388
assert len(elements) == 1
7489
assert elements[0].text == "From file object."
7590
assert elements[0].metadata.filename == "recording.wav"
91+
assert len(captured_temp_path) == 1, "expected exactly one temp file to be created"
92+
assert not Path(captured_temp_path[0]).exists(), "temp file was not deleted after partitioning"
93+
94+
95+
@patch(
96+
"unstructured.partition.audio.SpeechToTextAgent.get_agent",
97+
)
98+
def test_partition_audio_cleans_up_temp_file_when_transcription_raises(mock_get_agent):
99+
mock_agent = mock_get_agent.return_value
100+
mock_agent.transcribe_segments.side_effect = RuntimeError("transcription failed")
101+
102+
captured_temp_path: list[str] = []
103+
real_named_temp = tempfile.NamedTemporaryFile
104+
105+
def spy_named_temp(*args, **kwargs):
106+
ctx = real_named_temp(*args, **kwargs)
107+
captured_temp_path.append(ctx.name)
108+
return ctx
109+
110+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
111+
tmp.write(b"\x00" * 44)
112+
tmp.flush()
113+
tmp.seek(0)
114+
with patch("unstructured.partition.audio.tempfile.NamedTemporaryFile", spy_named_temp):
115+
with pytest.raises(RuntimeError, match="transcription failed"):
116+
partition_audio(file=tmp)
117+
118+
assert len(captured_temp_path) == 1, "expected exactly one temp file to be created"
119+
assert not Path(captured_temp_path[0]).exists(), "temp file was not deleted after exception"
76120

77121

78122
@patch(
79-
"unstructured.partition.audio.SpeechToTextAgent.get_instance",
123+
"unstructured.partition.audio.SpeechToTextAgent.get_agent",
80124
)
81-
def test_partition_audio_empty_transcript_returns_empty_list(mock_get_instance):
82-
mock_agent = mock_get_instance.return_value
125+
def test_partition_audio_empty_transcript_returns_empty_list(mock_get_agent):
126+
mock_agent = mock_get_agent.return_value
83127
mock_agent.transcribe_segments.return_value = []
84128

85129
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
@@ -96,10 +140,10 @@ def test_partition_audio_empty_transcript_returns_empty_list(mock_get_instance):
96140

97141

98142
@patch(
99-
"unstructured.partition.audio.SpeechToTextAgent.get_instance",
143+
"unstructured.partition.audio.SpeechToTextAgent.get_agent",
100144
)
101-
def test_partition_audio_returns_one_element_per_segment(mock_get_instance):
102-
mock_agent = mock_get_instance.return_value
145+
def test_partition_audio_returns_one_element_per_segment(mock_get_agent):
146+
mock_agent = mock_get_agent.return_value
103147
mock_agent.transcribe_segments.return_value = [
104148
{"text": "First segment.", "start": 0.0, "end": 1.0},
105149
{"text": "Second segment.", "start": 1.0, "end": 2.5},
@@ -153,3 +197,152 @@ def test_wav_file_type_is_partitionable():
153197
assert FileType.WAV.is_partitionable
154198
assert FileType.WAV.partitioner_shortname == "audio"
155199
assert FileType.WAV.partitioner_function_name == "partition_audio"
200+
201+
202+
# ================================================================================================
203+
# partition_audio parameter forwarding
204+
# ================================================================================================
205+
206+
207+
@patch("unstructured.partition.audio.SpeechToTextAgent.get_agent")
208+
def test_partition_audio_forwards_custom_stt_agent_to_get_agent(mock_get_agent):
209+
mock_agent = mock_get_agent.return_value
210+
mock_agent.transcribe_segments.return_value = [
211+
{"text": "Custom agent output.", "start": 0.0, "end": 1.0},
212+
]
213+
custom_module = (
214+
"unstructured.partition.utils.speech_to_text.whisper_stt.SpeechToTextAgentWhisper"
215+
)
216+
217+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
218+
path = tmp.name
219+
tmp.write(b"\x00" * 44)
220+
tmp.flush()
221+
222+
try:
223+
partition_audio(filename=path, stt_agent=custom_module)
224+
finally:
225+
Path(path).unlink(missing_ok=True)
226+
227+
mock_get_agent.assert_called_once_with(custom_module)
228+
229+
230+
@patch("unstructured.partition.audio.SpeechToTextAgent.get_agent")
231+
def test_partition_audio_forwards_language_to_transcribe_segments(mock_get_agent):
232+
mock_agent = mock_get_agent.return_value
233+
mock_agent.transcribe_segments.return_value = [
234+
{"text": "Hola mundo.", "start": 0.0, "end": 1.5},
235+
]
236+
237+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
238+
path = tmp.name
239+
tmp.write(b"\x00" * 44)
240+
tmp.flush()
241+
242+
try:
243+
elements = partition_audio(filename=path, language="es")
244+
finally:
245+
Path(path).unlink(missing_ok=True)
246+
247+
mock_agent.transcribe_segments.assert_called_once_with(path, language="es")
248+
assert elements[0].text == "Hola mundo."
249+
250+
251+
# ================================================================================================
252+
# Whitespace-only segment filtering
253+
# ================================================================================================
254+
255+
256+
@patch("unstructured.partition.audio.SpeechToTextAgent.get_agent")
257+
def test_partition_audio_filters_whitespace_only_segments(mock_get_agent):
258+
mock_agent = mock_get_agent.return_value
259+
mock_agent.transcribe_segments.return_value = [
260+
{"text": " ", "start": 0.0, "end": 0.5},
261+
{"text": "Real content.", "start": 0.5, "end": 2.0},
262+
{"text": "\t\n", "start": 2.0, "end": 2.5},
263+
]
264+
265+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
266+
path = tmp.name
267+
tmp.write(b"\x00" * 44)
268+
tmp.flush()
269+
270+
try:
271+
elements = partition_audio(filename=path)
272+
finally:
273+
Path(path).unlink(missing_ok=True)
274+
275+
assert len(elements) == 1
276+
assert elements[0].text == "Real content."
277+
278+
279+
# ================================================================================================
280+
# SpeechToTextAgent unit tests
281+
# ================================================================================================
282+
283+
284+
class TestSpeechToTextAgentInterface:
285+
"""Unit tests for the SpeechToTextAgent base class."""
286+
287+
def test_get_agent_uses_env_config_when_no_module_given(self):
288+
from unittest.mock import patch as _patch
289+
290+
from unstructured.partition.utils.speech_to_text.speech_to_text_interface import (
291+
SpeechToTextAgent,
292+
)
293+
294+
with _patch.object(SpeechToTextAgent, "get_instance") as mock_get_instance:
295+
SpeechToTextAgent.get_agent(None)
296+
called_with = mock_get_instance.call_args[0][0]
297+
assert "SpeechToTextAgent" in called_with or "Whisper" in called_with
298+
299+
def test_get_agent_passes_explicit_module_to_get_instance(self):
300+
from unittest.mock import patch as _patch
301+
302+
from unstructured.partition.utils.speech_to_text.speech_to_text_interface import (
303+
SpeechToTextAgent,
304+
)
305+
306+
custom = "unstructured.partition.utils.speech_to_text.whisper_stt.SpeechToTextAgentWhisper"
307+
with _patch.object(SpeechToTextAgent, "get_instance") as mock_get_instance:
308+
SpeechToTextAgent.get_agent(custom)
309+
mock_get_instance.assert_called_once_with(custom)
310+
311+
def test_get_instance_rejects_non_whitelisted_module(self):
312+
from unstructured.partition.utils.speech_to_text.speech_to_text_interface import (
313+
SpeechToTextAgent,
314+
)
315+
316+
with pytest.raises(ValueError, match="must be in the whitelist"):
317+
SpeechToTextAgent.get_instance("evil.module.EvilAgent")
318+
319+
def test_transcribe_segments_default_delegates_to_transcribe(self):
320+
"""Base transcribe_segments() wraps transcribe() in a single segment."""
321+
322+
from unstructured.partition.utils.speech_to_text.speech_to_text_interface import (
323+
SpeechToTextAgent,
324+
)
325+
326+
# Create a minimal concrete subclass
327+
class _StubAgent(SpeechToTextAgent):
328+
def transcribe(self, audio_path: str, *, language=None) -> str:
329+
return "stub text"
330+
331+
agent = _StubAgent()
332+
segments = agent.transcribe_segments("fake.wav")
333+
assert len(segments) == 1
334+
assert segments[0]["text"] == "stub text"
335+
assert segments[0]["start"] == 0.0
336+
assert segments[0]["end"] == 0.0
337+
338+
def test_transcribe_segments_default_returns_empty_for_blank_text(self):
339+
from unstructured.partition.utils.speech_to_text.speech_to_text_interface import (
340+
SpeechToTextAgent,
341+
)
342+
343+
class _BlankAgent(SpeechToTextAgent):
344+
def transcribe(self, audio_path: str, *, language=None) -> str:
345+
return " "
346+
347+
agent = _BlankAgent()
348+
assert agent.transcribe_segments("fake.wav") == []

unstructured/documents/elements.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,10 @@ def field_consolidation_strategies(cls) -> dict[str, ConsolidationStrategy]:
527527
"text_as_html": cls.STRING_CONCATENATE,
528528
"table_as_cells": cls.FIRST, # -- only occurs in Table --
529529
"url": cls.FIRST,
530+
# TODO: ideally a chunk spanning multiple audio segments would keep min(start) and
531+
# max(end) across its constituent elements. ConsolidationStrategy currently has no
532+
# MIN/MAX variants, so DROP is the safe fallback for now. Add MIN/MAX strategies
533+
# and switch these to cls.MIN / cls.MAX when that work is done.
530534
"segment_start_seconds": cls.DROP,
531535
"segment_end_seconds": cls.DROP,
532536
"key_value_pairs": cls.DROP, # -- only occurs in FormKeysValues --

unstructured/partition/audio.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from unstructured.partition.utils.speech_to_text.speech_to_text_interface import (
1616
SpeechToTextAgent,
1717
)
18+
from unstructured.utils import is_temp_file_path
1819

1920

2021
@apply_metadata(FileType.WAV)
@@ -70,16 +71,13 @@ def partition_audio(
7071
agent = SpeechToTextAgent.get_agent(stt_agent)
7172
segments = agent.transcribe_segments(audio_path, language=language)
7273
finally:
73-
if filename is None and audio_path.startswith(tempfile.gettempdir()):
74+
if filename is None and is_temp_file_path(audio_path):
7475
Path(audio_path).unlink(missing_ok=True)
7576

7677
if not segments:
7778
return []
7879

79-
base_metadata = ElementMetadata(
80-
last_modified=get_last_modified_date(filename) if filename else None,
81-
)
82-
base_metadata.detection_origin = "speech_to_text"
80+
last_modified = get_last_modified_date(filename) if filename else None
8381

8482
elements: list[Element] = []
8583
for seg in segments:
@@ -88,7 +86,7 @@ def partition_audio(
8886
continue
8987
element = NarrativeText(text=text)
9088
element.metadata = ElementMetadata(
91-
last_modified=base_metadata.last_modified,
89+
last_modified=last_modified,
9290
segment_start_seconds=seg["start"],
9391
segment_end_seconds=seg["end"],
9492
)

unstructured/partition/utils/config.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pathlib import Path
1414
from typing import Optional
1515

16-
from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT
16+
from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT, STT_AGENT_WHISPER
1717

1818

1919
@lru_cache(maxsize=1)
@@ -124,8 +124,6 @@ def STT_AGENT_CACHE_SIZE(self) -> int:
124124
@property
125125
def STT_AGENT(self) -> str:
126126
"""Speech-to-text agent module for partition_audio (e.g. Whisper)."""
127-
from unstructured.partition.utils.constants import STT_AGENT_WHISPER
128-
129127
return self._get_string("STT_AGENT", STT_AGENT_WHISPER)
130128

131129
@property
@@ -149,12 +147,22 @@ def WHISPER_DEVICE(self) -> str:
149147

150148
@property
151149
def WHISPER_FP16(self) -> bool:
152-
"""Use FP16 for Whisper transcription when True (default).
150+
"""Use FP16 for Whisper transcription.
153151
154-
FP16 gives roughly 2x GPU speedup on CUDA with minimal quality impact.
155-
Set WHISPER_FP16=false to disable (e.g. for CPU or compatibility).
152+
FP16 gives roughly 2x GPU speedup on CUDA with minimal quality impact, but is
153+
unsupported on CPU and will raise a RuntimeError there. The default is auto-detected:
154+
True when a CUDA GPU is available, False otherwise.
155+
Set WHISPER_FP16=true/false explicitly to override.
156156
"""
157-
return self._get_bool("WHISPER_FP16", True)
157+
env_val = self._get_string("WHISPER_FP16")
158+
if env_val:
159+
return env_val.lower() in ("true", "1", "t")
160+
try:
161+
import torch
162+
163+
return bool(torch.cuda.is_available())
164+
except ImportError:
165+
return False
158166

159167
@property
160168
def EXTRACT_IMAGE_BLOCK_CROP_HORIZONTAL_PAD(self) -> int:

0 commit comments

Comments
 (0)