11import os
2- from typing import Dict
2+ from typing import Dict , Optional
33
44import numpy as np
55import torch
@@ -31,10 +31,18 @@ def get_session():
3131 return _session_cache
3232
3333
34+ def get_device ():
35+ return "cuda" if torch .cuda .is_available () else "cpu"
36+
37+
3438def setup_control_mlps (
35- features : int = 1024 , device : str = "cuda" , dtype : torch .dtype = torch .float16
39+ features : int = 1024 ,
40+ device : Optional [str ] = None ,
41+ dtype : torch .dtype = torch .float16 ,
3642) -> Dict [str , torch .nn .Module ]:
3743 ret = {}
44+ if device is None :
45+ device = get_device ()
3846 for mlp in CONTROL_MLPS :
3947 ret [mlp ] = setup_control_mlp (mlp , features , device , dtype )
4048 return ret
@@ -43,12 +51,18 @@ def setup_control_mlps(
4351def setup_control_mlp (
4452 material_parameter : str ,
4553 features : int = 1024 ,
46- device : str = "cuda" ,
54+ device : Optional [ str ] = None ,
4755 dtype : torch .dtype = torch .float16 ,
4856):
57+ if device is None :
58+ device = get_device ()
59+
4960 net = control_mlp (features )
5061 net .load_state_dict (
51- torch .load (os .path .join (file_dir , f"model_weights/{ material_parameter } .pt" ))
62+ torch .load (
63+ os .path .join (file_dir , f"model_weights/{ material_parameter } .pt" ),
64+ map_location = device
65+ )
5266 )
5367 net .to (device , dtype = dtype )
5468 net .eval ()
@@ -95,9 +109,12 @@ def download_ip_adapter():
95109
96110
97111def setup_pipeline (
98- device : str = "cuda" ,
112+ device : Optional [ str ] = None ,
99113 dtype : torch .dtype = torch .float16 ,
100114):
115+ if device is None :
116+ device = get_device ()
117+
101118 download_ip_adapter ()
102119
103120 cur_block = ("up" , 0 , 1 )
@@ -135,7 +152,10 @@ def setup_pipeline(
135152 )
136153
137154
138- def get_dpt_model (device : str = "cuda" , dtype : torch .dtype = torch .float16 ):
155+ def get_dpt_model (device : Optional [str ] = None , dtype : torch .dtype = torch .float16 ):
156+ if device is None :
157+ device = get_device ()
158+
139159 image_processor = DPTImageProcessor .from_pretrained ("Intel/dpt-hybrid-midas" )
140160 model = DPTForDepthEstimation .from_pretrained ("Intel/dpt-hybrid-midas" )
141161 model .to (device , dtype = dtype )
@@ -144,9 +164,12 @@ def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
144164
145165
146166def run_dpt_depth (
147- image : Image .Image , model , processor , device : str = "cuda"
167+ image : Image .Image , model , processor , device : Optional [ str ] = None
148168) -> Image .Image :
149169 """Run DPT depth estimation on an image."""
170+ if device is None :
171+ device = get_device ()
172+
150173 # Prepare image
151174 inputs = processor (images = image , return_tensors = "pt" ).to (device , dtype = model .dtype )
152175
0 commit comments