File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 55Licensed under GNU Lesser General Public License v3.0
66"""
77
8+ # Check which backends are installed and get available backends
9+ # (Emits a warning if neither TensorFlow nor PyTorch is installed)
10+ from dlclive .utils import get_available_backends
11+ _AVAILABLE_BACKENDS = get_available_backends ()
12+
813from dlclive .display import Display
914from dlclive .dlclive import DLCLive
1015from dlclive .processor .processor import Processor
Original file line number Diff line number Diff line change 1616from dlclive .utils import download_file
1717from dlclive .benchmark import benchmark_videos
1818from dlclive .engine import Engine
19+ from dlclive .utils import get_available_backends
1920
2021MODEL_NAME = "superanimal_quadruped"
2122SNAPSHOT_NAME = "snapshot-700000.pb"
@@ -93,4 +94,15 @@ def main():
9394
9495
9596if __name__ == "__main__" :
97+
98+ # Get available backends (emits a warning if neither TensorFlow nor PyTorch is installed)
99+ available_backends : list [Engine ] = get_available_backends ()
100+ print (f"Available backends: { [b .value for b in available_backends ]} " )
101+
102+ # TODO: JR add support for PyTorch in check_install.py (requires some exported pytorch model to be downloaded)
103+ if not Engine .TENSORFLOW in available_backends :
104+ raise NotImplementedError (
105+ "TensorFlow is not installed. Currently check_install.py only supports testing the TensorFlow installation."
106+ )
107+
96108 main ()
Original file line number Diff line number Diff line change 1212import urllib .error
1313
1414from dlclive .exceptions import DLCLiveWarning
15+ from dlclive .engine import Engine
1516
1617try :
1718 import skimage
@@ -214,6 +215,41 @@ def decode_fourcc(cc):
214215 return decoded
215216
216217
218+ def get_available_backends () -> list [Engine ]:
219+ """
220+ Check which backends (TensorFlow or PyTorch) are installed.
221+
222+ Returns:
223+ list[str]: List of installed backends. Possible values: ["tensorflow"], ["pytorch"],
224+ or ["tensorflow", "pytorch"]. Returns an empty list if neither is installed.
225+
226+ Warns:
227+ DLCLiveWarning: If neither TensorFlow nor PyTorch is installed.
228+ """
229+ backends = []
230+
231+ try :
232+ import tensorflow
233+ backends .append (Engine .TENSORFLOW )
234+ except (ImportError , ModuleNotFoundError ):
235+ pass
236+
237+ try :
238+ import torch
239+ backends .append (Engine .PYTORCH )
240+ except (ImportError , ModuleNotFoundError ):
241+ pass
242+
243+ if not backends :
244+ warnings .warn (
245+ "Neither TensorFlow nor PyTorch is installed. One of these is required to use DLCLive!"
246+ "Install with: pip install deeplabcut-live[tf] or pip install deeplabcut-live[pytorch]" ,
247+ DLCLiveWarning ,
248+ )
249+
250+ return backends
251+
252+
217253def download_file (url : str , filepath : str , chunk_size : int = 8192 ) -> None :
218254 """
219255 Download a file from a URL with progress bar and error handling.
Original file line number Diff line number Diff line change 44from dlclive .engine import Engine
55
66
7+ # TODO: JR include separate functional tests for torch and tf backends
78@pytest .mark .functional
89def test_benchmark_script_runs (tmp_path ):
910 datafolder = tmp_path / "Data-DLC-live-benchmark"
You can’t perform that action at this time.
0 commit comments