Skip to content

Commit 7d13ffa

Browse files
Ensure CUDA availability checks and mock pyannote.audio imports for tests
1 parent 1aadce7 commit 7d13ffa

2 files changed

Lines changed: 13 additions & 5 deletions

File tree

src/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@
8282

8383
def _reset_vram(device: str) -> None:
8484
"""Reset peak VRAM stats for the given device before a stage begins."""
85-
if _has_torch and device.startswith("cuda"):
85+
if _has_torch and device.startswith("cuda") and _torch.cuda.is_available():
8686
_torch.cuda.reset_peak_memory_stats(device)
8787

8888

8989
def _read_vram(device: str) -> int | None:
9090
"""Return peak VRAM allocated (bytes) since the last reset, or None on CPU."""
91-
if _has_torch and device.startswith("cuda"):
91+
if _has_torch and device.startswith("cuda") and _torch.cuda.is_available():
9292
return _torch.cuda.max_memory_allocated(device)
9393
return None
9494

tests/test_diarizer.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Unit tests for src.diarizer — all Pyannote calls are mocked."""
22

3+
import sys
34
from pathlib import Path
45
from unittest.mock import MagicMock, patch
56

@@ -37,8 +38,13 @@ def test_loads_and_moves_to_device(self):
3738
mock_pipeline_cls = MagicMock()
3839
mock_pipeline_cls.from_pretrained.return_value = mock_pipeline
3940

40-
# Pipeline is imported locally inside load_pipeline, so patch at the source.
41-
with patch("pyannote.audio.Pipeline", mock_pipeline_cls):
41+
# `load_pipeline` does a local `from pyannote.audio import Pipeline`.
42+
# Patching the live attribute triggers the real module import, which pulls in
43+
# matplotlib and crashes in headless CI. Pre-populate sys.modules instead so
44+
# the local import inside load_pipeline picks up the mock without any I/O.
45+
mock_pa = MagicMock()
46+
mock_pa.Pipeline = mock_pipeline_cls
47+
with patch.dict(sys.modules, {"pyannote.audio": mock_pa}):
4248
result = load_pipeline(DEFAULT_MODEL, "cuda", "hf_fake")
4349

4450
mock_pipeline_cls.from_pretrained.assert_called_once_with(DEFAULT_MODEL)
@@ -49,7 +55,9 @@ def test_raises_diarization_error_on_failure(self):
4955
mock_pipeline_cls = MagicMock()
5056
mock_pipeline_cls.from_pretrained.side_effect = RuntimeError("model not found")
5157

52-
with patch("pyannote.audio.Pipeline", mock_pipeline_cls):
58+
mock_pa = MagicMock()
59+
mock_pa.Pipeline = mock_pipeline_cls
60+
with patch.dict(sys.modules, {"pyannote.audio": mock_pa}):
5361
with pytest.raises(DiarizationError, match="Failed to load Pyannote pipeline"):
5462
load_pipeline(DEFAULT_MODEL, "cuda", "hf_fake")
5563

0 commit comments

Comments
 (0)