Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
__pycache__/
*.py[cod]


# Distribution / packaging
workspace/
build/
Expand Down Expand Up @@ -34,3 +33,4 @@ input_images*/
src/debug_main.py
temp*.png
/outputs
.idea
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "depth-anything-3"
version = "0.0.0"
description = "Depth Anything 3"
readme = "README.md"
requires-python = ">=3.9, <=3.13"
requires-python = ">=3.10, <=3.13"
license = { text = "Apache-2.0" }
authors = [{ name = "Your Name" }]

Expand All @@ -21,7 +21,7 @@ dependencies = [
"imageio",
"numpy<2",
"opencv-python",
"xformers",
"xformers; platform_system!='Darwin'",
"open3d",
"fastapi",
"uvicorn",
Expand All @@ -43,7 +43,9 @@ dependencies = [

[project.optional-dependencies]
app = ["gradio>=5", "pillow>=9.0"] # requires that python3>=3.10
gs = ["gsplat @ git+https://github.com/nerfstudio-project/gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70"]
gs = [
"gsplat>=1.0.0; platform_system!='Darwin'"
]
all = ["depth-anything-3[app,gs]"]


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ huggingface_hub
imageio
numpy<2
opencv-python
xformers
xformers; platform_system!='Darwin'
open3d
fastapi
uvicorn
Expand Down
75 changes: 62 additions & 13 deletions src/depth_anything_3/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image

from depth_anything_3.cache import get_model_cache
from depth_anything_3.cfg import create_object, load_config
from depth_anything_3.registry import MODEL_REGISTRY
from depth_anything_3.specs import Prediction
Expand Down Expand Up @@ -72,29 +73,64 @@ class DepthAnything3(nn.Module, PyTorchModelHubMixin):

_commit_hash: str | None = None # Set by mixin when loading from Hub

def __init__(self, model_name: str = "da3-large", **kwargs):
def __init__(self, model_name: str = "da3-large", device: str | torch.device | None = None, use_cache: bool = True, **kwargs):
"""
Initialize DepthAnything3 with specified preset.

Args:
model_name: The name of the model preset to use.
Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'.
**kwargs: Additional keyword arguments (currently unused).
model_name: The name of the model preset to use.
Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'.
device: Target device ('cuda', 'mps', 'cpu'). If None, auto-detect.
use_cache: Whether to use model caching (default: True).
Set to False to force reload model from disk.
**kwargs: Additional keyword arguments (currently unused).
"""
super().__init__()
self.model_name = model_name
self.use_cache = use_cache

# Build the underlying network
# Determine device
if device is None:
device = self._auto_detect_device()
self.device = torch.device(device) if isinstance(device, str) else device

# Load model configuration
self.config = load_config(MODEL_REGISTRY[self.model_name])
self.model = create_object(self.config)

# Build or retrieve model from cache
if use_cache:
cache = get_model_cache()
self.model = cache.get(
model_name=self.model_name,
device=self.device,
loader_fn=lambda: self._create_model()
)
else:
logger.info(f"Model cache disabled, loading {self.model_name} from disk")
self.model = self._create_model()

# Ensure model is on correct device and in eval mode
self.model = self.model.to(self.device)
self.model.eval()

# Initialize processors
self.input_processor = InputProcessor()
self.output_processor = OutputProcessor()

# Device management (set by user)
self.device = None
def _auto_detect_device(self) -> torch.device:
"""Auto-detect best available device."""
if torch.cuda.is_available():
return torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")

def _create_model(self) -> nn.Module:
"""Create and return new model instance."""
model = create_object(self.config)
model.eval()
return model

@torch.inference_mode()
def forward(
Expand Down Expand Up @@ -304,20 +340,33 @@ def _prepare_model_inputs(
extrinsics: torch.Tensor | None,
intrinsics: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Prepare tensors for model input."""
"""
Prepare tensors for model input with optimized device transfer.

Uses non_blocking=True for async CPU→GPU transfers, which overlaps
data transfer with compute when possible.
"""
device = self._get_model_device()

# Move images to model device
# Pin memory for faster CPU→GPU transfer (CUDA only)
if device.type == "cuda" and imgs_cpu.device.type == "cpu":
imgs_cpu = imgs_cpu.pin_memory()

# Move images to model device with non-blocking transfer
imgs = imgs_cpu.to(device, non_blocking=True)[None].float()

# Convert camera parameters to tensors
# Convert camera parameters to tensors with non-blocking transfer
ex_t = (
extrinsics.to(device, non_blocking=True)[None].float()
extrinsics.pin_memory().to(device, non_blocking=True)[None].float()
if extrinsics is not None and device.type == "cuda"
else extrinsics.to(device, non_blocking=True)[None].float()
if extrinsics is not None
else None
)
in_t = (
intrinsics.to(device, non_blocking=True)[None].float()
intrinsics.pin_memory().to(device, non_blocking=True)[None].float()
if intrinsics is not None and device.type == "cuda"
else intrinsics.to(device, non_blocking=True)[None].float()
if intrinsics is not None
else None
)
Expand Down
189 changes: 189 additions & 0 deletions src/depth_anything_3/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""
Model caching utilities for Depth Anything 3.

Provides model caching functionality to avoid reloading model weights on every instantiation.
This significantly reduces latency for repeated model creation (2-5s gain).
"""

from __future__ import annotations

import threading
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn

from depth_anything_3.utils.logger import logger


class ModelCache:
"""
Thread-safe singleton cache for Depth Anything 3 models.

Caches loaded model weights to avoid reloading from disk on every instantiation.
Each unique combination of (model_name, device) is cached separately.

Usage:
cache = ModelCache()
model = cache.get(model_name, device, loader_fn)
# loader_fn is only called if cache miss

Thread Safety:
Uses threading.Lock to ensure thread-safe access to cache.

Memory Management:
- Models are kept in cache until explicitly cleared
- Use clear() to free memory when needed
- Use clear_device() to clear specific device models
"""

_instance: Optional["ModelCache"] = None
_lock = threading.Lock()

def __new__(cls):
"""Singleton pattern to ensure single cache instance."""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance

def __init__(self):
"""Initialize cache storage."""
if self._initialized:
return

self._cache: Dict[Tuple[str, str], nn.Module] = {}
self._cache_lock = threading.Lock()
self._initialized = True
logger.info("ModelCache initialized")

def get(
self,
model_name: str,
device: torch.device | str,
loader_fn: callable,
) -> nn.Module:
"""
Get cached model or load if not in cache.

Args:
model_name: Name of the model (e.g., "da3-large")
device: Target device (cuda, mps, cpu)
loader_fn: Function to load model if cache miss
Should return nn.Module

Returns:
Cached or freshly loaded model on specified device

Example:
>>> cache = ModelCache()
>>> model = cache.get(
... "da3-large",
... "cuda",
... lambda: create_model()
... )
"""
device_str = str(device)
cache_key = (model_name, device_str)

with self._cache_lock:
if cache_key in self._cache:
logger.debug(f"Model cache HIT: {model_name} on {device_str}")
return self._cache[cache_key]

logger.info(f"Model cache MISS: {model_name} on {device_str}. Loading...")
model = loader_fn()
self._cache[cache_key] = model
logger.info(f"Model cached: {model_name} on {device_str}")

return model

def clear(self) -> None:
"""
Clear entire cache and free memory.

Removes all cached models and forces garbage collection.
Useful when switching between many different models.
"""
with self._cache_lock:
num_cached = len(self._cache)
self._cache.clear()

# Force garbage collection to free GPU memory
import gc

gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if hasattr(torch, "mps") and torch.backends.mps.is_available():
torch.mps.empty_cache()

logger.info(f"Model cache cleared ({num_cached} models removed)")

def clear_device(self, device: torch.device | str) -> None:
"""
Clear all models on specific device.

Args:
device: Device to clear (e.g., "cuda", "mps", "cpu")

Example:
>>> cache = ModelCache()
>>> cache.clear_device("cuda") # Clear all CUDA models
"""
device_str = str(device)

with self._cache_lock:
keys_to_remove = [key for key in self._cache if key[1] == device_str]
for key in keys_to_remove:
del self._cache[key]

# Free device memory
if "cuda" in device_str and torch.cuda.is_available():
torch.cuda.empty_cache()
elif "mps" in device_str and hasattr(torch, "mps") and torch.backends.mps.is_available():
torch.mps.empty_cache()

logger.info(f"Model cache cleared for device {device_str} ({len(keys_to_remove)} models removed)")

def get_cache_info(self) -> Dict[str, int]:
"""
Get cache statistics.

Returns:
Dictionary with cache info:
- total: Total number of cached models
- by_device: Number of models per device
"""
with self._cache_lock:
info = {
"total": len(self._cache),
"by_device": {},
}

for model_name, device_str in self._cache.keys():
if device_str not in info["by_device"]:
info["by_device"][device_str] = 0
info["by_device"][device_str] += 1

return info


# Global singleton instance
_global_cache = ModelCache()


def get_model_cache() -> ModelCache:
"""
Get global model cache instance.

Returns:
Singleton ModelCache instance

Example:
>>> from depth_anything_3.cache import get_model_cache
>>> cache = get_model_cache()
>>> cache.clear()
"""
return _global_cache
7 changes: 5 additions & 2 deletions src/depth_anything_3/utils/io/input_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,16 @@ def __call__(
proc_imgs, out_sizes, out_ixts = self._unify_batch_shapes(proc_imgs, out_sizes, out_ixts)

batch_tensor = self._stack_batch(proc_imgs)

# Zero-copy conversion: torch.from_numpy shares memory with numpy array
# Only works when array is C-contiguous (which np.asarray ensures)
out_exts = (
torch.from_numpy(np.asarray(out_exts)).float()
torch.from_numpy(np.ascontiguousarray(np.asarray(out_exts))).float()
if out_exts is not None and out_exts[0] is not None
else None
)
out_ixts = (
torch.from_numpy(np.asarray(out_ixts)).float()
torch.from_numpy(np.ascontiguousarray(np.asarray(out_ixts))).float()
if out_ixts is not None and out_ixts[0] is not None
else None
)
Expand Down
Loading