Skip to content

Commit 38da250

Browse files
committed
fix(depth-estimation): replace dead torch.hub.load with HF hub + pip package
torch.hub.load('LiheYoung/Depth-Anything-V2', ...) returns 404. Switch to using the DepthAnythingV2 class directly from the depth_anything_v2 pip package, with weights downloaded via huggingface_hub.hf_hub_download (cached after the first download). Tested: the model loads successfully on MPS (Apple Silicon).
1 parent a7bb895 commit 38da250

1 file changed

Lines changed: 38 additions & 6 deletions

File tree

skills/transformation/depth-estimation/scripts/transform.py

Lines changed: 38 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,8 @@ def parse_extra_args(self, parser: argparse.ArgumentParser):
5959

6060
def load_model(self, config: dict) -> dict:
6161
import torch
62+
from depth_anything_v2.dpt import DepthAnythingV2
63+
from huggingface_hub import hf_hub_download
6264

6365
model_name = config.get("model", "depth-anything-v2-small")
6466
self.colormap_id = COLORMAP_MAP.get(config.get("colormap", "inferno"), 1)
@@ -67,13 +69,43 @@ def load_model(self, config: dict) -> dict:
6769

6870
_log(f"Loading {model_name} on {self.device}", self._tag)
6971

70-
# Load model via torch hub
71-
hub_name = model_name.replace("-", "_")
72-
self.model = torch.hub.load(
73-
"LiheYoung/Depth-Anything-V2",
74-
hub_name,
75-
trust_repo=True,
72+
# Model configs: encoder name, features, HF repo, weight filename
73+
MODEL_CONFIGS = {
74+
"depth-anything-v2-small": {
75+
"encoder": "vits", "features": 64,
76+
"out_channels": [48, 96, 192, 384],
77+
"repo": "depth-anything/Depth-Anything-V2-Small",
78+
"filename": "depth_anything_v2_vits.pth",
79+
},
80+
"depth-anything-v2-base": {
81+
"encoder": "vitb", "features": 128,
82+
"out_channels": [96, 192, 384, 768],
83+
"repo": "depth-anything/Depth-Anything-V2-Base",
84+
"filename": "depth_anything_v2_vitb.pth",
85+
},
86+
"depth-anything-v2-large": {
87+
"encoder": "vitl", "features": 256,
88+
"out_channels": [256, 512, 1024, 1024],
89+
"repo": "depth-anything/Depth-Anything-V2-Large",
90+
"filename": "depth_anything_v2_vitl.pth",
91+
},
92+
}
93+
94+
cfg = MODEL_CONFIGS.get(model_name)
95+
if not cfg:
96+
raise ValueError(f"Unknown model: {model_name}. Choose from: {list(MODEL_CONFIGS.keys())}")
97+
98+
# Download weights from HuggingFace Hub (cached after first download)
99+
_log(f"Downloading weights from HF: {cfg['repo']}", self._tag)
100+
weights_path = hf_hub_download(cfg["repo"], cfg["filename"])
101+
102+
# Build model from pip package
103+
self.model = DepthAnythingV2(
104+
encoder=cfg["encoder"],
105+
features=cfg["features"],
106+
out_channels=cfg["out_channels"],
76107
)
108+
self.model.load_state_dict(torch.load(weights_path, map_location=self.device, weights_only=True))
77109
self.model.to(self.device)
78110
self.model.eval()
79111

0 commit comments

Comments (0)