Skip to content

Commit 17a3a52

Browse files
committed
Add centralized get_available_backends function.
The function is called in __init__ and emits a warning if neither tf nor torch are installed. Added todo comments in tests that should test both tensorflow and pytorch installation.
1 parent 5be9409 commit 17a3a52

4 files changed

Lines changed: 54 additions & 0 deletions

File tree

dlclive/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
Licensed under GNU Lesser General Public License v3.0
66
"""
77

8+
# Check which backends are installed and get available backends
9+
# (Emits a warning if neither TensorFlow nor PyTorch is installed)
10+
from dlclive.utils import get_available_backends
11+
_AVAILABLE_BACKENDS = get_available_backends()
12+
813
from dlclive.display import Display
914
from dlclive.dlclive import DLCLive
1015
from dlclive.processor.processor import Processor

dlclive/check_install/check_install.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from dlclive.utils import download_file
1717
from dlclive.benchmark import benchmark_videos
1818
from dlclive.engine import Engine
19+
from dlclive.utils import get_available_backends
1920

2021
MODEL_NAME = "superanimal_quadruped"
2122
SNAPSHOT_NAME = "snapshot-700000.pb"
@@ -93,4 +94,15 @@ def main():
9394

9495

9596
if __name__ == "__main__":
97+
98+
# Get available backends (emits a warning if neither TensorFlow nor PyTorch is installed)
99+
available_backends: list[Engine] = get_available_backends()
100+
print(f"Available backends: {[b.value for b in available_backends]}")
101+
102+
# TODO: JR add support for PyTorch in check_install.py (requires some exported pytorch model to be downloaded)
103+
if not Engine.TENSORFLOW in available_backends:
104+
raise NotImplementedError(
105+
"TensorFlow is not installed. Currently check_install.py only supports testing the TensorFlow installation."
106+
)
107+
96108
main()

dlclive/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import urllib.error
1313

1414
from dlclive.exceptions import DLCLiveWarning
15+
from dlclive.engine import Engine
1516

1617
try:
1718
import skimage
@@ -214,6 +215,41 @@ def decode_fourcc(cc):
214215
return decoded
215216

216217

218+
def get_available_backends() -> list[Engine]:
219+
"""
220+
Check which backends (TensorFlow or PyTorch) are installed.
221+
222+
Returns:
223+
list[str]: List of installed backends. Possible values: ["tensorflow"], ["pytorch"],
224+
or ["tensorflow", "pytorch"]. Returns an empty list if neither is installed.
225+
226+
Warns:
227+
DLCLiveWarning: If neither TensorFlow nor PyTorch is installed.
228+
"""
229+
backends = []
230+
231+
try:
232+
import tensorflow
233+
backends.append(Engine.TENSORFLOW)
234+
except (ImportError, ModuleNotFoundError):
235+
pass
236+
237+
try:
238+
import torch
239+
backends.append(Engine.PYTORCH)
240+
except (ImportError, ModuleNotFoundError):
241+
pass
242+
243+
if not backends:
244+
warnings.warn(
245+
"Neither TensorFlow nor PyTorch is installed. One of these is required to use DLCLive!"
246+
"Install with: pip install deeplabcut-live[tf] or pip install deeplabcut-live[pytorch]",
247+
DLCLiveWarning,
248+
)
249+
250+
return backends
251+
252+
217253
def download_file(url: str, filepath: str, chunk_size: int = 8192) -> None:
218254
"""
219255
Download a file from a URL with progress bar and error handling.

tests/test_benchmark_script.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from dlclive.engine import Engine
55

66

7+
# TODO: JR include separate functional tests for torch and tf backends
78
@pytest.mark.functional
89
def test_benchmark_script_runs(tmp_path):
910
datafolder = tmp_path / "Data-DLC-live-benchmark"

0 commit comments

Comments
 (0)