Skip to content

Commit 4865edb

Browse files
committed
Add Engine helpers for model path detection
Introduce Engine helpers (is_pytorch_model_path, is_tensorflow_model_dir_path) and centralize model-file/TF-dir detection. Replace the old is_model_file usage with these helpers across settings_store and main_window, fixing a bug in suffix checking and ensuring .pb TensorFlow models are validated via their parent directory. Remove legacy is_model_file from utils. Also update camera UI to expect dlg.request_scan_cancel instead of _on_scan_cancel, refine scan-cancel test synchronization, and add/adjust unit tests to cover the new Engine detection logic.
1 parent c7b11c9 commit 4865edb

File tree

8 files changed

+164
-53
lines changed

8 files changed

+164
-53
lines changed

dlclivegui/gui/camera_config/ui_blocks.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,8 @@ def build_available_cameras_group(dlg: CameraConfigDialog) -> QGroupBox:
193193
dlg.scan_cancel_btn.setIcon(dlg.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
194194
dlg.scan_cancel_btn.setVisible(False)
195195

196-
# The original UI block connects cancel here; preserve that.
197-
# dlg must provide _on_scan_cancel
198-
if hasattr(dlg, "_on_scan_cancel"):
196+
# dlg must provide request_scan_cancel()
197+
if hasattr(dlg, "request_scan_cancel"):
199198
dlg.scan_cancel_btn.clicked.connect(dlg.request_scan_cancel) # type: ignore[attr-defined]
200199

201200
available_layout.addWidget(dlg.scan_cancel_btn)

dlclivegui/gui/main_window.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ def _parse_json(self, value: str) -> dict:
876876

877877
def _dlc_settings_from_ui(self) -> DLCProcessorSettings:
878878
model_path = self.model_path_edit.text().strip()
879-
if Path(model_path).exists() and Path(model_path).suffix in (".pb"):
879+
if Path(model_path).exists() and Path(model_path).suffix == ".pb":
880880
# IMPORTANT NOTE: DLClive expects a directory for TensorFlow models,
881881
# so if user selects a .pb file, we should pass the parent directory to DLCLive
882882
model_path = str(Path(model_path).parent)
@@ -1004,7 +1004,10 @@ def _action_browse_model(self) -> None:
10041004
return
10051005

10061006
try:
1007-
DLCLiveProcessor.get_model_backend(str(file_path))
1007+
if file_path.suffix == ".pb":
1008+
# For TensorFlow, DLCLive expects a directory, so we pass the parent directory for validation.
# NOTE(review): model_check_path is only assigned inside this branch — for any
# non-.pb file the call below raises NameError. Initialize model_check_path = file_path
# before the branch (or use an else) so non-TensorFlow models are still validated.
1009+
model_check_path = file_path.parent
1010+
DLCLiveProcessor.get_model_backend(str(model_check_path))
10081011
except FileNotFoundError as e:
10091012
QMessageBox.warning(self, "Model selection error", str(e))
10101013
return

dlclivegui/temp/engine.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,28 @@
22
from pathlib import Path
33

44

5+
# TODO @C-Achard decide if this moves to utils,
6+
# or if we update dlclive.Engine to have these methods and use that instead of a separate enum here.
7+
# The latter would be more cohesive but also creates a dependency from utils to dlclive,
8+
# pending release of dlclive
59
class Engine(Enum):
610
TENSORFLOW = "tensorflow"
711
PYTORCH = "pytorch"
812

13+
@staticmethod
14+
def is_pytorch_model_path(model_path: str | Path) -> bool:
15+
path = Path(model_path)
16+
return path.is_file() and path.suffix.lower() in (".pt", ".pth")
17+
18+
@staticmethod
19+
def is_tensorflow_model_dir_path(model_path: str | Path) -> bool:
20+
path = Path(model_path)
21+
if not path.is_dir():
22+
return False
23+
has_cfg = (path / "pose_cfg.yaml").is_file()
24+
has_pb = any(p.suffix.lower() == ".pb" for p in path.glob("*.pb"))
25+
return has_cfg and has_pb
26+
927
@classmethod
1028
def from_model_type(cls, model_type: str) -> "Engine":
1129
if model_type.lower() == "pytorch":
@@ -23,13 +41,10 @@ def from_model_path(cls, model_path: str | Path) -> "Engine":
2341
raise FileNotFoundError(f"Model path does not exist: {model_path}")
2442

2543
if path.is_dir():
26-
has_cfg = (path / "pose_cfg.yaml").is_file()
27-
# has_cfg is DLClive specific and is considered a requirement for TF live models.
28-
has_pb = any(p.suffix == ".pb" for p in path.glob("*.pb"))
29-
if has_cfg and has_pb:
44+
if cls.is_tensorflow_model_dir_path(path):
3045
return cls.TENSORFLOW
3146
elif path.is_file():
32-
if path.suffix in (".pt", ".pth"):
47+
if cls.is_pytorch_model_path(path):
3348
return cls.PYTORCH
3449

3550
raise ValueError(f"Could not determine engine from model path: {model_path}")

dlclivegui/utils/settings_store.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from PySide6.QtCore import QSettings
88

99
from ..config import ApplicationSettings
10-
from .utils import is_model_file
10+
from ..temp import Engine # type: ignore # TODO use main package enum when released
1111

1212
logger = logging.getLogger(__name__)
1313

@@ -132,7 +132,7 @@ def load_last(self) -> str | None:
132132
try:
133133
pp = Path(path)
134134
# Accept a PyTorch model *file*, or any file sitting inside a valid TF model
# directory (e.g. a .pb next to pose_cfg.yaml) — note this no longer checks
# the file's own suffix, only its parent for the TF case.
135-
if pp.is_file() and is_model_file(str(pp)):
135+
if pp.is_file() and (Engine.is_pytorch_model_path(pp) or Engine.is_tensorflow_model_dir_path(pp.parent)):
136136
return str(pp)
137137
except Exception:
138138
logger.debug("Last model path not valid/usable: %s", path)
@@ -172,10 +172,12 @@ def save_if_valid(self, path: str) -> None:
172172
self._settings.setValue("dlc/last_model_dir", model_dir_norm)
173173

174174
# Persist the model path if it is a valid PyTorch model file, or a file inside a
# TF model directory. NOTE(review): a TF model *directory* path itself is no
# longer persisted (that branch is commented out below) — confirm this is
# intended, and consider also requiring p.suffix == ".pb" before accepting the
# parent-directory check, so unrelated files in a TF dir are not persisted.
175-
if p.is_file() and is_model_file(str(p)):
175+
if Engine.is_pytorch_model_path(p):
176176
self._settings.setValue("dlc/last_model_path", str(p))
177-
elif p.is_dir() and self._looks_like_tf_model_dir(p):
177+
elif p.parent.is_dir() and Engine.is_tensorflow_model_dir_path(p.parent):
178178
self._settings.setValue("dlc/last_model_path", str(p))
179+
# elif p.is_dir() and Engine.is_tensorflow_model_dir_path(p):
180+
# self._settings.setValue("dlc/last_model_path", str(p))
179181

180182
except Exception:
181183
logger.debug("Failed to save model path: %s", path, exc_info=True)
@@ -204,9 +206,9 @@ def resolve(self, config_path: str | None) -> str:
204206
if cfg:
205207
try:
206208
p = Path(cfg)
207-
if p.is_file() and is_model_file(cfg):
209+
if p.is_file() and Engine.is_pytorch_model_path(p):
208210
return cfg
209-
if p.is_dir() and self._looks_like_tf_model_dir(p):
211+
if p.is_dir() and Engine.is_tensorflow_model_dir_path(p):
210212
return cfg
211213
except Exception:
212214
logger.debug("Config path not usable: %s", cfg)

dlclivegui/utils/utils.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,9 @@
88
from datetime import datetime
99
from pathlib import Path
1010

11-
SUPPORTED_MODELS = [".pt", ".pth", ".pb"]
1211
_INVALID_CHARS = re.compile(r"[^A-Za-z0-9._-]+")
1312

1413

15-
def is_model_file(file_path: Path | str) -> bool:
16-
if not isinstance(file_path, Path):
17-
file_path = Path(file_path)
18-
if not file_path.is_file():
19-
return False
20-
return file_path.suffix.lower() in SUPPORTED_MODELS
21-
22-
2314
def sanitize_name(name: str, *, fallback: str = "session") -> str:
2415
"""Make a user-provided string safe for filesystem paths."""
2516
name = (name or "").strip()

tests/gui/camera_config/test_cam_dialog_e2e.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,13 @@ def slow_detect(backend, max_devices=10, should_cancel=None, progress_cb=None, *
304304

305305
qtbot.mouseClick(dialog.scan_cancel_btn, Qt.LeftButton)
306306

307+
# scan_finished = UI stable, not necessarily worker fully stopped / controls unlocked
307308
with qtbot.waitSignal(dialog.scan_finished, timeout=3000):
308309
pass
309310

310-
assert dialog.refresh_btn.isEnabled()
311-
assert dialog.backend_combo.isEnabled()
311+
# Wait until scan controls are unlocked (worker finished)
312+
qtbot.waitUntil(lambda: dialog.refresh_btn.isEnabled(), timeout=3000)
313+
qtbot.waitUntil(lambda: dialog.backend_combo.isEnabled(), timeout=3000)
312314

313315

314316
@pytest.mark.gui

tests/utils/test_settings_store.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,19 +235,21 @@ def test_model_path_store_resolve_prefers_config_path_when_valid(tmp_path: Path)
235235
assert mps.resolve(str(model)) == str(model)
236236

237237

238-
def test_model_path_store_resolve_falls_back_to_persisted(tmp_path: Path):
238+
def test_model_path_store_resolve_falls_back_to_persisted_tf_dir(tmp_path: Path):
239239
settings = InMemoryQSettings()
240240
mps = store.ModelPathStore(settings=settings)
241241

242-
persisted = tmp_path / "persisted.pb"
243-
persisted.write_text("x")
244-
settings.setValue("dlc/last_model_path", str(persisted))
242+
tf_dir = tmp_path / "tf_model"
243+
tf_dir.mkdir()
244+
(tf_dir / "pose_cfg.yaml").write_text("cfg: 1\n")
245+
(tf_dir / "graph.pb").write_text("pb")
246+
247+
settings.setValue("dlc/last_model_path", str(tf_dir / "graph.pb"))
245248

246-
# invalid config path
247249
bad = tmp_path / "notamodel.onnx"
248250
bad.write_text("x")
249251

250-
assert mps.resolve(str(bad)) == str(persisted)
252+
assert mps.resolve(str(bad)) == str(tf_dir / "graph.pb")
251253

252254

253255
def test_model_path_store_resolve_returns_empty_when_nothing_valid(tmp_path: Path):

tests/utils/test_utils.py

Lines changed: 117 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,138 @@
1+
from __future__ import annotations
2+
13
from pathlib import Path
24

35
import pytest
46

57
import dlclivegui.utils.utils as u
8+
from dlclivegui.temp.engine import Engine  # i.e. dlclivegui/temp/engine.py
69

710
pytestmark = pytest.mark.unit
811

912

13+
# NOTE @C-Achard: These tests are currently in test_utils.py for convenience,
14+
# but we may want to use dlclive.Engine directly
15+
# and possibly move these tests to dlclive's test suite
1016
# -----------------------------
11-
# is_model_file
17+
# Engine.from_model_type
1218
# -----------------------------
13-
@pytest.mark.unit
14-
def test_is_model_file_true_for_supported_extensions(tmp_path: Path):
15-
for ext in [".pt", ".pth", ".pb"]:
16-
p = tmp_path / f"model{ext}"
17-
p.write_text("x")
18-
assert u.is_model_file(p) is True
19-
assert u.is_model_file(str(p)) is True # also accepts str
19+
@pytest.mark.parametrize(
20+
"inp, expected",
21+
[
22+
("pytorch", Engine.PYTORCH),
23+
("PYTORCH", Engine.PYTORCH),
24+
("tensorflow", Engine.TENSORFLOW),
25+
("TensorFlow", Engine.TENSORFLOW),
26+
("base", Engine.TENSORFLOW),
27+
("tensorrt", Engine.TENSORFLOW),
28+
("lite", Engine.TENSORFLOW),
29+
],
30+
)
31+
def test_engine_from_model_type(inp: str, expected: Engine):
32+
assert Engine.from_model_type(inp) == expected
33+
34+
35+
def test_engine_from_model_type_unknown():
36+
with pytest.raises(ValueError):
37+
Engine.from_model_type("onnx")
2038

21-
# case-insensitive
22-
p2 = tmp_path / "MODEL.PT"
23-
p2.write_text("x")
24-
assert u.is_model_file(p2) is True
2539

40+
# -----------------------------
41+
# Engine.is_pytorch_model_path
42+
# -----------------------------
43+
@pytest.mark.parametrize("ext", [".pt", ".pth"])
44+
def test_engine_is_pytorch_model_path_true(tmp_path: Path, ext: str):
45+
p = tmp_path / f"model{ext}"
46+
p.write_text("x")
47+
assert Engine.is_pytorch_model_path(p) is True
48+
assert Engine.is_pytorch_model_path(str(p)) is True
2649

27-
@pytest.mark.unit
28-
def test_is_model_file_false_for_missing_or_dir(tmp_path: Path):
29-
missing = tmp_path / "missing.pt"
30-
assert u.is_model_file(missing) is False
3150

51+
def test_engine_is_pytorch_model_path_false_for_missing(tmp_path: Path):
52+
p = tmp_path / "missing.pt"
53+
assert Engine.is_pytorch_model_path(p) is False
54+
55+
56+
def test_engine_is_pytorch_model_path_false_for_dir(tmp_path: Path):
3257
d = tmp_path / "model.pt"
3358
d.mkdir()
34-
assert u.is_model_file(d) is False
59+
assert Engine.is_pytorch_model_path(d) is False
60+
61+
62+
def test_engine_is_pytorch_model_path_case_insensitive(tmp_path: Path):
63+
# is_pytorch_model_path lowercases the suffix, so uppercase extensions (.PT) are accepted
64+
p = tmp_path / "MODEL.PT"
65+
p.write_text("x")
66+
assert Engine.is_pytorch_model_path(p) is True
67+
68+
69+
# -----------------------------
70+
# Engine.is_tensorflow_model_dir_path
71+
# -----------------------------
72+
def _make_tf_dir(tmp_path: Path, *, with_cfg: bool = True, with_pb: bool = True, pb_name: str = "graph.pb") -> Path:
73+
d = tmp_path / "tf_model"
74+
d.mkdir()
75+
if with_cfg:
76+
(d / "pose_cfg.yaml").write_text("cfg: 1\n")
77+
if with_pb:
78+
(d / pb_name).write_text("pbdata")
79+
return d
80+
81+
82+
def test_engine_is_tensorflow_model_dir_path_true(tmp_path: Path):
83+
d = _make_tf_dir(tmp_path, with_cfg=True, with_pb=True)
84+
assert Engine.is_tensorflow_model_dir_path(d) is True
85+
assert Engine.is_tensorflow_model_dir_path(str(d)) is True
86+
87+
88+
def test_engine_is_tensorflow_model_dir_path_false_missing_cfg(tmp_path: Path):
89+
d = _make_tf_dir(tmp_path, with_cfg=False, with_pb=True)
90+
assert Engine.is_tensorflow_model_dir_path(d) is False
91+
92+
93+
def test_engine_is_tensorflow_model_dir_path_false_missing_pb(tmp_path: Path):
94+
d = _make_tf_dir(tmp_path, with_cfg=True, with_pb=False)
95+
assert Engine.is_tensorflow_model_dir_path(d) is False
96+
97+
98+
def test_engine_is_tensorflow_model_dir_path_case_insensitive_pb(tmp_path: Path):
99+
# requires is_tensorflow_model_dir_path to match .pb suffixes case-insensitively;
# note glob("*.pb") alone misses "GRAPH.PB" on case-sensitive filesystems
100+
d = _make_tf_dir(tmp_path, with_cfg=True, with_pb=True, pb_name="GRAPH.PB")
101+
assert Engine.is_tensorflow_model_dir_path(d) is True
102+
103+
104+
# -----------------------------
105+
# Engine.from_model_path
106+
# -----------------------------
107+
def test_engine_from_model_path_missing_raises(tmp_path: Path):
108+
missing = tmp_path / "does_not_exist.pt"
109+
with pytest.raises(FileNotFoundError):
110+
Engine.from_model_path(missing)
111+
112+
113+
def test_engine_from_model_path_pytorch_file(tmp_path: Path):
114+
p = tmp_path / "net.pth"
115+
p.write_text("x")
116+
assert Engine.from_model_path(p) == Engine.PYTORCH
117+
118+
119+
def test_engine_from_model_path_tensorflow_dir(tmp_path: Path):
120+
d = _make_tf_dir(tmp_path, with_cfg=True, with_pb=True)
121+
assert Engine.from_model_path(d) == Engine.TENSORFLOW
122+
123+
124+
def test_engine_from_model_path_dir_not_tf_raises(tmp_path: Path):
125+
d = tmp_path / "some_dir"
126+
d.mkdir()
127+
with pytest.raises(ValueError):
128+
Engine.from_model_path(d)
129+
35130

36-
bad = tmp_path / "model.onnx"
37-
bad.write_text("x")
38-
assert u.is_model_file(bad) is False
131+
def test_engine_from_model_path_file_not_pytorch_raises(tmp_path: Path):
132+
p = tmp_path / "model.pb"
133+
p.write_text("x") # PB file alone is not a TF dir
134+
with pytest.raises(ValueError):
135+
Engine.from_model_path(p)
39136

40137

41138
# -----------------------------

0 commit comments

Comments
 (0)