Skip to content

Commit 31e09d8

Browse files
committed
Normalize model_type and fix model paths
Coerce DLCProcessorSettings.model_type to a lowercase string (accepting Enum or string inputs) and validate allowed backends (pytorch, tensorflow). Update UI to handle TensorFlow .pb models by using the parent directory for DLCLive, restrict file dialog to existing files, add existence checks and backend detection when selecting a model. Improve ModelPathStore: robust path normalization, separate helpers for existing file/dir checks, smarter save/load/resolve logic, and better start-dir/suggest-file heuristics. Minor cleanup: remove duplicate import and clarify a TF model detection comment in engine.
1 parent 48b274c commit 31e09d8

File tree

5 files changed

+195
-77
lines changed

5 files changed

+195
-77
lines changed

dlclivegui/config.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
from __future__ import annotations
33

44
import json
5+
from enum import Enum
56
from pathlib import Path
67
from typing import Any, Literal
78

89
from pydantic import BaseModel, Field, field_validator, model_validator
910

10-
from dlclivegui.temp import Engine
11-
1211
Rotation = Literal[0, 90, 180, 270]
1312
TileLayout = Literal["auto", "2x2", "1x4", "4x1"]
1413
Precision = Literal["FP32", "FP16"]
@@ -241,14 +240,46 @@ class DLCProcessorSettings(BaseModel):
241240
resize: float = Field(default=1.0, gt=0)
242241
precision: Precision = "FP32"
243242
additional_options: dict[str, Any] = Field(default_factory=dict)
244-
model_type: Engine = Engine.PYTORCH
243+
model_type: str = "pytorch"
245244
single_animal: bool = True
246245

247246
@field_validator("dynamic", mode="before")
248247
@classmethod
249248
def _coerce_dynamic(cls, v):
250249
return DynamicCropModel.from_tupleish(v)
251250

251+
@field_validator("model_type", mode="before")
252+
@classmethod
253+
def _coerce_model_type(cls, v):
254+
"""
255+
Accept:
256+
- "pytorch"/"tensorflow"/etc as strings
257+
- Enum instances (e.g. Engine.PYTORCH) and store their .value
258+
Always return a lowercase string.
259+
"""
260+
if v is None or v == "":
261+
return "pytorch"
262+
263+
# If caller passed Engine enum or any Enum, use its value
264+
if isinstance(v, Enum):
265+
v = v.value
266+
267+
# If caller passed something with a `.value` attribute (defensive)
268+
if not isinstance(v, str) and hasattr(v, "value"):
269+
v = v.value
270+
271+
if not isinstance(v, str):
272+
raise TypeError(f"model_type must be a string or Enum, got {type(v)!r}")
273+
274+
v = v.strip().lower()
275+
276+
# Optional: enforce allowed values
277+
allowed = {"pytorch", "tensorflow"}
278+
if v not in allowed:
279+
raise ValueError(f"Unknown model type: {v!r}. Allowed: {sorted(allowed)}")
280+
281+
return v
282+
252283

253284
class BoundingBoxSettings(BaseModel):
254285
enabled: bool = False

dlclivegui/gui/main_window.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,10 @@ def _parse_json(self, value: str) -> dict:
876876

877877
def _dlc_settings_from_ui(self) -> DLCProcessorSettings:
878878
model_path = self.model_path_edit.text().strip()
879+
if Path(model_path).exists() and Path(model_path).suffix in (".pb"):
880+
# IMPORTANT NOTE: DLClive expects a directory for TensorFlow models,
881+
# so if user selects a .pb file, we should pass the parent directory to DLCLive
882+
model_path = str(Path(model_path).parent)
879883
if model_path == "":
880884
raise ValueError("Model path cannot be empty. Please enter a valid path to a DLCLive model file.")
881885
try:
@@ -976,13 +980,12 @@ def _action_browse_model(self) -> None:
976980
preselect = self._model_path_store.suggest_selected_file()
977981

978982
dlg = QFileDialog(self, "Select DLCLive model file")
979-
dlg.setFileMode(QFileDialog.FileMode.AnyFile)
983+
dlg.setFileMode(QFileDialog.FileMode.ExistingFile)
980984
dlg.setNameFilters(
981985
[
982986
"Model files (*.pt *.pth)",
983987
"PyTorch models (*.pt *.pth)",
984-
# "TensorFlow models (*.pb)",
985-
"TensorFlow model directory (*.*)",
988+
"TensorFlow models (*.pb)",
986989
]
987990
)
988991
dlg.setDirectory(start_dir)
@@ -995,7 +998,20 @@ def _action_browse_model(self) -> None:
995998
selected = dlg.selectedFiles()
996999
if not selected:
9971000
return
998-
file_path = selected[0]
1001+
file_path = Path(selected[0]).expanduser()
1002+
if not file_path.exists():
1003+
QMessageBox.warning(self, "File not found", f"The selected file does not exist:\n{file_path}")
1004+
return
1005+
1006+
try:
1007+
DLCLiveProcessor.get_model_backend(str(file_path))
1008+
except FileNotFoundError as e:
1009+
QMessageBox.warning(self, "Model selection error", str(e))
1010+
return
1011+
except ValueError as e:
1012+
QMessageBox.warning(self, "Model selection error", str(e))
1013+
return
1014+
file_path = str(file_path)
9991015
self.model_path_edit.setText(file_path)
10001016

10011017
# Persist model path + directory

dlclivegui/services/dlc_processor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from PySide6.QtCore import QObject, Signal
1717

1818
from dlclivegui.config import DLCProcessorSettings
19-
20-
# from dlclivegui.config import DLCProcessorSettings
2119
from dlclivegui.processors.processor_utils import instantiate_from_scan
2220
from dlclivegui.temp import Engine # type: ignore # TODO use main package enum when released
2321

dlclivegui/temp/engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def from_model_path(cls, model_path: str | Path) -> "Engine":
2424

2525
if path.is_dir():
2626
has_cfg = (path / "pose_cfg.yaml").is_file()
27+
# has_cfg is DLClive specific and is considered a requirement for TF live models.
2728
has_pb = any(p.suffix == ".pb" for p in path.glob("*.pb"))
2829
if has_cfg and has_pb:
2930
return cls.TENSORFLOW

dlclivegui/utils/settings_store.py

Lines changed: 140 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# dlclivegui/utils/settings_store.py
2+
from __future__ import annotations
3+
24
import logging
35
from pathlib import Path
46

@@ -70,124 +72,194 @@ class ModelPathStore:
7072
def __init__(self, settings: QSettings | None = None):
7173
self._settings = settings or QSettings("DeepLabCut", "DLCLiveGUI")
7274

73-
def _norm(self, p: str | None) -> str | None:
75+
# -------------------------
76+
# Normalization helpers
77+
# -------------------------
78+
def _as_path(self, p: str | None) -> Path | None:
79+
"""Best-effort conversion to Path (expand ~, interpret '.' as cwd)."""
7480
if not p:
7581
return None
82+
s = str(p).strip()
83+
if not s:
84+
return None
7685
try:
77-
return str(Path(p).expanduser().resolve())
86+
pp = Path(s).expanduser()
87+
if s in (".", "./"):
88+
pp = Path.cwd()
89+
return pp
7890
except Exception:
79-
logger.debug("Failed to normalize path: %s", p)
91+
logger.debug("Failed to parse path: %s", p)
92+
return None
93+
94+
def _norm_existing_dir(self, p: str | None) -> str | None:
95+
"""Return an absolute, resolved existing directory path, else None."""
96+
pp = self._as_path(p)
97+
if pp is None:
98+
return None
99+
try:
100+
# If a file was given, use its parent directory
101+
if pp.exists() and pp.is_file():
102+
pp = pp.parent
103+
104+
if pp.exists() and pp.is_dir():
105+
return str(pp.resolve())
106+
except Exception:
107+
logger.debug("Failed to normalize directory: %s", p)
108+
return None
109+
110+
def _norm_existing_path(self, p: str | None) -> str | None:
111+
"""Return an absolute, resolved existing path (file or dir), else None."""
112+
pp = self._as_path(p)
113+
if pp is None:
80114
return None
115+
try:
116+
if pp.exists():
117+
return str(pp.resolve())
118+
except Exception:
119+
logger.debug("Failed to normalize path: %s", p)
120+
return None
81121

122+
# -------------------------
123+
# Load
124+
# -------------------------
82125
def load_last(self) -> str | None:
126+
"""Return last model path if it still exists and looks usable."""
83127
val = self._settings.value("dlc/last_model_path")
84-
path = self._norm(str(val)) if val else None
128+
path = self._norm_existing_path(str(val)) if val else None
85129
if not path:
86130
return None
131+
87132
try:
88-
return path if is_model_file(path) else None
133+
pp = Path(path)
134+
# Accept a valid model *file*
135+
if pp.is_file() and is_model_file(str(pp)):
136+
return str(pp)
89137
except Exception:
90-
logger.debug("Last model path is not a valid model file: %s", path)
91-
return None
138+
logger.debug("Last model path not valid/usable: %s", path)
139+
140+
return None
92141

93142
def load_last_dir(self) -> str | None:
143+
"""Return last directory if it still exists and is a directory."""
94144
val = self._settings.value("dlc/last_model_dir")
95-
d = self._norm(str(val)) if val else None
96-
if not d:
97-
return None
98-
try:
99-
p = Path(d)
100-
return str(p) if p.exists() and p.is_dir() else None
101-
except Exception:
102-
logger.debug("Last model dir is not a valid directory: %s", d)
103-
return None
145+
d = self._norm_existing_dir(str(val)) if val else None
146+
return d
104147

148+
# -------------------------
149+
# Save
150+
# -------------------------
105151
def save_if_valid(self, path: str) -> None:
106-
"""Save last model *file* if it looks valid, and always save its directory."""
107-
path = self._norm(path) or ""
108-
if not path:
152+
"""
153+
Save last model path if it looks valid/usable, and always save its directory.
154+
- For files: always save parent directory.
155+
- For directories: save directory itself if it looks like a TF model dir.
156+
"""
157+
norm = self._norm_existing_path(path)
158+
if not norm:
109159
return
160+
110161
try:
111-
parent = str(Path(path).parent)
112-
self._settings.setValue("dlc/last_model_dir", parent)
162+
p = Path(norm)
163+
164+
# Always persist a *directory* that is safe for QFileDialog.setDirectory(...)
165+
if p.is_dir():
166+
model_dir = p
167+
else:
168+
model_dir = p.parent
169+
170+
model_dir_norm = self._norm_existing_dir(str(model_dir))
171+
if model_dir_norm:
172+
self._settings.setValue("dlc/last_model_dir", model_dir_norm)
173+
174+
# Persist model path if it is a valid model file, or a TF model directory
175+
if p.is_file() and is_model_file(str(p)):
176+
self._settings.setValue("dlc/last_model_path", str(p))
177+
elif p.is_dir() and self._looks_like_tf_model_dir(p):
178+
self._settings.setValue("dlc/last_model_path", str(p))
113179

114-
if is_model_file(path):
115-
self._settings.setValue("dlc/last_model_path", str(Path(path)))
116180
except Exception:
117-
logger.debug("Failed to save last model path: %s", path)
118-
pass
181+
logger.debug("Failed to save model path: %s", path, exc_info=True)
119182

120183
def save_last_dir(self, directory: str) -> None:
121-
directory = self._norm(directory) or ""
122-
if not directory:
184+
d = self._norm_existing_dir(directory)
185+
if not d:
123186
return
124187
try:
125-
p = Path(directory)
126-
if p.exists() and p.is_dir():
127-
self._settings.setValue("dlc/last_model_dir", str(p))
188+
self._settings.setValue("dlc/last_model_dir", d)
128189
except Exception:
129-
pass
190+
logger.debug("Failed to save last model dir: %s", d, exc_info=True)
130191

192+
# -------------------------
193+
# Resolve
194+
# -------------------------
131195
def resolve(self, config_path: str | None) -> str:
132-
"""Resolve the best model path to display in the UI."""
133-
config_path = self._norm(config_path)
134-
if config_path:
196+
"""
197+
Resolve the best model path to display in the UI.
198+
Preference:
199+
1) config_path if valid/usable
200+
2) persisted last model path if valid/usable
201+
3) empty
202+
"""
203+
cfg = self._norm_existing_path(config_path)
204+
if cfg:
135205
try:
136-
if is_model_file(config_path):
137-
return config_path
206+
p = Path(cfg)
207+
if p.is_file() and is_model_file(cfg):
208+
return cfg
209+
if p.is_dir() and self._looks_like_tf_model_dir(p):
210+
return cfg
138211
except Exception:
139-
logger.debug("Config path is not a valid model file: %s", config_path)
140-
pass
212+
logger.debug("Config path not usable: %s", cfg)
141213

142214
persisted = self.load_last()
143215
if persisted:
144-
try:
145-
if is_model_file(persisted):
146-
return persisted
147-
except Exception:
148-
pass
216+
return persisted
149217

150218
return ""
151219

152220
def suggest_start_dir(self, fallback_dir: str | None = None) -> str:
153-
"""Pick the best directory to start the file dialog in."""
221+
"""
222+
Pick the best directory to start file dialogs in.
223+
Guarantees: returns an existing absolute directory (never '.').
224+
"""
154225
# 1) last dir
155226
last_dir = self.load_last_dir()
156227
if last_dir:
157228
return last_dir
158229

159-
# 2) directory of last valid model file
160-
last_file = self.load_last()
161-
if last_file:
230+
# 2) directory of last valid model path
231+
last = self.load_last()
232+
if last:
162233
try:
163-
parent = Path(last_file).parent
164-
if parent.exists():
165-
return str(parent)
234+
p = Path(last)
235+
if p.is_file():
236+
parent = self._norm_existing_dir(str(p.parent))
237+
if parent:
238+
return parent
239+
elif p.is_dir():
240+
d = self._norm_existing_dir(str(p))
241+
if d:
242+
return d
166243
except Exception:
167-
logger.debug("Failed to get parent of last model file: %s", last_file)
168-
pass
244+
logger.debug("Failed to derive start dir from last model: %s", last)
169245

170-
# 3) fallback dir (config.model_directory) if valid
171-
if fallback_dir:
172-
try:
173-
p = Path(fallback_dir).expanduser()
174-
if p.exists() and p.is_dir():
175-
return str(p)
176-
except Exception:
177-
logger.debug("Fallback dir is not a valid directory: %s", fallback_dir)
178-
pass
246+
# 3) fallback dir (e.g. config.dlc.model_directory)
247+
fb = self._norm_existing_dir(fallback_dir)
248+
if fb:
249+
return fb
179250

180-
# 4) last resort: home
181-
return str(Path.home())
251+
# 4) last resort: cwd if exists else home
252+
cwd = self._norm_existing_dir(str(Path.cwd()))
253+
return cwd or str(Path.home())
182254

183255
def suggest_selected_file(self) -> str | None:
184-
"""Optional: return a file to preselect if it exists."""
185-
last_file = self.load_last()
186-
if not last_file:
256+
"""Return a file to preselect if it exists (only files, not directories)."""
257+
last = self.load_last()
258+
if not last:
187259
return None
188260
try:
189-
p = Path(last_file)
261+
p = Path(last)
190262
return str(p) if p.exists() and p.is_file() else None
191263
except Exception:
192-
logger.debug("Failed to check existence of last model file: %s", last_file)
264+
logger.debug("Failed to check existence of last model: %s", last)
193265
return None

0 commit comments

Comments
 (0)