Skip to content

Commit 74c7aa3

Browse files
committed
Use Engine enum and validate model path
Switch DLCProcessorSettings.model_type default from a raw string to the Engine enum (Engine.PYTORCH) and change DLCLiveProcessor.get_model_backend to return an Engine instead of a string for stronger typing. Update the GUI to validate empty model path input, reuse the resolved model backend (model_bknd) when building settings, improve the backend-detection error message, and allow selecting any file/directory in the model file dialog (adjusted name filter label for TensorFlow model directories). Also fix a relative import in dlclivegui.temp.__init__.py. Note: this changes the get_model_backend return type (callers expecting a string must use .value if needed).
1 parent 02821e1 commit 74c7aa3

File tree

4 files changed

+11
-9
lines changed

4 files changed

+11
-9
lines changed

dlclivegui/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class DLCProcessorSettings(BaseModel):
241241
resize: float = Field(default=1.0, gt=0)
242242
precision: Precision = "FP32"
243243
additional_options: dict[str, Any] = Field(default_factory=dict)
244-
model_type: Engine = "pytorch"
244+
model_type: Engine = Engine.PYTORCH
245245
single_animal: bool = True
246246

247247
@field_validator("dynamic", mode="before")

dlclivegui/gui/main_window.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -876,11 +876,13 @@ def _parse_json(self, value: str) -> dict:
876876

877877
def _dlc_settings_from_ui(self) -> DLCProcessorSettings:
878878
model_path = self.model_path_edit.text().strip()
879+
if model_path == "":
880+
raise ValueError("Model path cannot be empty. Please enter a valid path to a DLCLive model file.")
879881
try:
880-
DLCLiveProcessor.get_model_backend(model_path)
882+
model_bknd = DLCLiveProcessor.get_model_backend(model_path)
881883
except Exception as e:
882884
raise RuntimeError(
883-
"Could not determine model backend from path."
885+
"Could not determine model backend from path. "
884886
"Please ensure the model file is valid and has an appropriate extension "
885887
"(.pt, .pth for PyTorch or model directory for TensorFlow)."
886888
) from e
@@ -891,7 +893,7 @@ def _dlc_settings_from_ui(self) -> DLCProcessorSettings:
891893
dynamic=self._config.dlc.dynamic, # Preserve from config
892894
resize=self._config.dlc.resize, # Preserve from config
893895
precision=self._config.dlc.precision, # Preserve from config
894-
model_type=DLCLiveProcessor.get_model_backend(model_path),
896+
model_type=model_bknd,
895897
# additional_options=self._parse_json(self.additional_options_edit.toPlainText()),
896898
)
897899

@@ -974,13 +976,13 @@ def _action_browse_model(self) -> None:
974976
preselect = self._model_path_store.suggest_selected_file()
975977

976978
dlg = QFileDialog(self, "Select DLCLive model file")
977-
dlg.setFileMode(QFileDialog.FileMode.ExistingFile)
979+
dlg.setFileMode(QFileDialog.FileMode.AnyFile)
978980
dlg.setNameFilters(
979981
[
980982
"Model files (*.pt *.pth)",
981983
"PyTorch models (*.pt *.pth)",
982984
# "TensorFlow models (*.pb)",
983-
"All files (*.*)",
985+
"TensorFlow model directory (*.*)",
984986
]
985987
)
986988
dlg.setDirectory(start_dir)

dlclivegui/services/dlc_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def __init__(self) -> None:
100100
self._processor_overhead_times: deque[float] = deque(maxlen=60)
101101

102102
@staticmethod
103-
def get_model_backend(model_path: str) -> str:
104-
return Engine.from_model_path(model_path).value
103+
def get_model_backend(model_path: str) -> Engine:
104+
return Engine.from_model_path(model_path)
105105

106106
def configure(self, settings: DLCProcessorSettings, processor: Any | None = None) -> None:
107107
self._settings = settings

dlclivegui/temp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from engine import Engine # type: ignore
1+
from .engine import Engine # type: ignore
22

33
__all__ = ["Engine"]

0 commit comments

Comments
 (0)