-
-
Notifications
You must be signed in to change notification settings - Fork 183
Expand file tree
/
Copy pathtest_gpu_runtime_setup.py
More file actions
31 lines (22 loc) · 1.45 KB
/
test_gpu_runtime_setup.py
File metadata and controls
31 lines (22 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from unittest.mock import patch
from audio_separator.separator.separator import Separator
def test_setup_accelerated_inferencing_device_preloads_onnxruntime_dependencies():
    """Successful ONNX Runtime DLL preload is followed by torch device setup.

    Every collaborator of ``setup_accelerated_inferencing_device`` except the
    preload call itself is stubbed out. The test checks three things:

    * ``ort.preload_dlls`` is invoked exactly once with no arguments;
    * ``setup_torch_device`` receives the object returned by ``get_system_info``;
    * no warning is logged (the preload-failure fallback path is not taken).
    """
    separator = Separator(info_only=True)
    # Opaque sentinel: only identity matters, so it can be compared by the
    # ``setup_torch_device`` assertion below.
    system_info = object()

    # ``create=True`` lets the patch succeed even when the installed
    # onnxruntime build does not define ``preload_dlls``.
    with patch.object(
        separator, "get_system_info", return_value=system_info
    ), patch.object(separator, "check_ffmpeg_installed"), patch.object(
        separator, "log_onnxruntime_packages"
    ), patch(
        "audio_separator.separator.separator.ort.preload_dlls", create=True
    ) as mock_preload, patch.object(
        separator, "setup_torch_device"
    ) as mock_setup, patch.object(
        separator.logger, "warning"
    ) as mock_warning:
        separator.setup_accelerated_inferencing_device()

    mock_preload.assert_called_once_with()
    mock_setup.assert_called_once_with(system_info)
    # On the happy path the preload-failure warning must not be emitted.
    mock_warning.assert_not_called()
def test_setup_accelerated_inferencing_device_continues_when_preload_fails():
    """A failing ONNX Runtime DLL preload is logged and does not abort setup.

    ``ort.preload_dlls`` is patched to raise ``RuntimeError``. The test checks
    three things:

    * the preload was actually attempted (so the failure branch really ran);
    * ``setup_torch_device`` is still called with the system info;
    * exactly one warning is logged.
    """
    separator = Separator(info_only=True)
    # Opaque sentinel: only identity matters for the setup assertion.
    system_info = object()

    # ``create=True`` lets the patch succeed even when the installed
    # onnxruntime build does not define ``preload_dlls``.
    with patch.object(
        separator, "get_system_info", return_value=system_info
    ), patch.object(separator, "check_ffmpeg_installed"), patch.object(
        separator, "log_onnxruntime_packages"
    ), patch(
        "audio_separator.separator.separator.ort.preload_dlls",
        side_effect=RuntimeError("boom"),
        create=True,
    ) as mock_preload, patch.object(
        separator, "setup_torch_device"
    ) as mock_setup, patch.object(
        separator.logger, "warning"
    ) as mock_warning:
        separator.setup_accelerated_inferencing_device()

    # A Mock records the call before raising its side_effect, so this proves
    # the warning below came from the preload failure path.
    mock_preload.assert_called_once_with()
    mock_setup.assert_called_once_with(system_info)
    mock_warning.assert_called_once()