Skip to content

Commit 4df2dbd

Browse files
committed
Fix device placement without CUDA
1 parent 1fa76b5 commit 4df2dbd

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

marble.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Dict
2+
from typing import Dict, Optional
33

44
import numpy as np
55
import torch
@@ -31,10 +31,18 @@ def get_session():
3131
return _session_cache
3232

3333

34+
def get_device():
    """Pick the default torch device string.

    Returns "cuda" when a CUDA-capable device is visible to this
    process, and falls back to "cpu" otherwise.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
36+
37+
3438
def setup_control_mlps(
    features: int = 1024,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
) -> Dict[str, torch.nn.Module]:
    """Build one control MLP per entry in CONTROL_MLPS.

    Args:
        features: Hidden-feature width passed through to each MLP.
        device: Target device string; when None, resolved via
            get_device() (CUDA if available, else CPU).
        dtype: Parameter dtype for the constructed networks.

    Returns:
        Mapping from control-MLP name to its initialized module.
    """
    resolved_device = get_device() if device is None else device
    return {
        name: setup_control_mlp(name, features, resolved_device, dtype)
        for name in CONTROL_MLPS
    }
@@ -43,12 +51,18 @@ def setup_control_mlps(
4351
def setup_control_mlp(
4452
material_parameter: str,
4553
features: int = 1024,
46-
device: str = "cuda",
54+
device: Optional[str] = None,
4755
dtype: torch.dtype = torch.float16,
4856
):
57+
if device is None:
58+
device = get_device()
59+
4960
net = control_mlp(features)
5061
net.load_state_dict(
51-
torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
62+
torch.load(
63+
os.path.join(file_dir, f"model_weights/{material_parameter}.pt"),
64+
map_location=device
65+
)
5266
)
5367
net.to(device, dtype=dtype)
5468
net.eval()
@@ -95,9 +109,12 @@ def download_ip_adapter():
95109

96110

97111
def setup_pipeline(
98-
device: str = "cuda",
112+
device: Optional[str] = None,
99113
dtype: torch.dtype = torch.float16,
100114
):
115+
if device is None:
116+
device = get_device()
117+
101118
download_ip_adapter()
102119

103120
cur_block = ("up", 0, 1)
@@ -135,7 +152,10 @@ def setup_pipeline(
135152
)
136153

137154

138-
def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
155+
def get_dpt_model(device: Optional[str] = None, dtype: torch.dtype = torch.float16):
156+
if device is None:
157+
device = get_device()
158+
139159
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
140160
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
141161
model.to(device, dtype=dtype)
@@ -144,9 +164,12 @@ def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
144164

145165

146166
def run_dpt_depth(
147-
image: Image.Image, model, processor, device: str = "cuda"
167+
image: Image.Image, model, processor, device: Optional[str] = None
148168
) -> Image.Image:
149169
"""Run DPT depth estimation on an image."""
170+
if device is None:
171+
device = get_device()
172+
150173
# Prepare image
151174
inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)
152175

0 commit comments

Comments
 (0)