Skip to content

Commit 45e3499

Browse files
committed
Add temporary Engine enum and use in dlc_processor
Introduce a temporary dlclivegui.temp package with an Engine enum to detect model engines. engine.py provides Engine.TENSORFLOW and Engine.PYTORCH plus helpers from_model_type and from_model_path (checks extensions/pose_cfg.yaml and existence). Add __init__.py and update dlclivegui/services/dlc_processor.py to import Engine from dlclivegui.temp (with a TODO to switch to the upstream dlclive package when available).
1 parent 1bb70c7 commit 45e3499

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

dlclivegui/services/dlc_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
try: # pragma: no cover - optional dependency
2929
from dlclive import (
3030
DLCLive, # type: ignore
31-
Engine, # type: ignore
3231
)
32+
from dlclivegui.temp import Engine # type: ignore # TODO use main package one when released
3333
except Exception as e: # pragma: no cover - handled gracefully
3434
logger.error(f"dlclive package could not be imported: {e}")
3535
DLCLive = None # type: ignore[assignment]

dlclivegui/temp/__init__.py

Whitespace-only changes.

dlclivegui/temp/engine.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from enum import Enum
2+
from pathlib import Path
3+
4+
class Engine(Enum):
5+
TENSORFLOW = "tensorflow"
6+
PYTORCH = "pytorch"
7+
8+
@classmethod
9+
def from_model_type(cls, model_type: str) -> "Engine":
10+
if model_type.lower() == "pytorch":
11+
return cls.PYTORCH
12+
elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"):
13+
return cls.TENSORFLOW
14+
else:
15+
raise ValueError(f"Unknown model type: {model_type}")
16+
17+
@classmethod
18+
def from_model_path(cls, model_path: str | Path) -> "Engine":
19+
path = Path(model_path)
20+
21+
if not path.exists():
22+
raise FileNotFoundError(f"Model path does not exist: {model_path}")
23+
24+
if path.is_dir():
25+
has_cfg = (path / "pose_cfg.yaml").is_file()
26+
has_pb = any(p.suffix == ".pb" for p in path.glob("*.pb"))
27+
if has_cfg and has_pb:
28+
return cls.TENSORFLOW
29+
elif path.is_file():
30+
if path.suffix == ".pt":
31+
return cls.PYTORCH
32+
33+
raise ValueError(f"Could not determine engine from model path: {model_path}")

0 commit comments

Comments
 (0)