66import folder_paths
77from typing_extensions import override
88from comfy_api .latest import ComfyExtension , io
9+ import comfy .model_management
910
1011try :
1112 from spandrel_extra_arches import EXTRA_REGISTRY
@@ -78,13 +79,15 @@ def execute(cls, upscale_model, image) -> io.NodeOutput:
7879 tile = 512
7980 overlap = 32
8081
82+ output_device = comfy .model_management .intermediate_device ()
83+
8184 oom = True
8285 try :
8386 while oom :
8487 try :
8588 steps = in_img .shape [0 ] * comfy .utils .get_tiled_scale_steps (in_img .shape [3 ], in_img .shape [2 ], tile_x = tile , tile_y = tile , overlap = overlap )
8689 pbar = comfy .utils .ProgressBar (steps )
87- s = comfy .utils .tiled_scale (in_img , lambda a : upscale_model (a ) , tile_x = tile , tile_y = tile , overlap = overlap , upscale_amount = upscale_model .scale , pbar = pbar )
90+ s = comfy .utils .tiled_scale (in_img , lambda a : upscale_model (a . float ()) , tile_x = tile , tile_y = tile , overlap = overlap , upscale_amount = upscale_model .scale , pbar = pbar , output_device = output_device )
8891 oom = False
8992 except Exception as e :
9093 model_management .raise_non_oom (e )
@@ -94,7 +97,7 @@ def execute(cls, upscale_model, image) -> io.NodeOutput:
9497 finally :
9598 upscale_model .to ("cpu" )
9699
97- s = torch .clamp (s .movedim (- 3 ,- 1 ), min = 0 , max = 1.0 )
100+ s = torch .clamp (s .movedim (- 3 ,- 1 ), min = 0 , max = 1.0 ). to ( comfy . model_management . intermediate_dtype ())
98101 return io .NodeOutput (s )
99102
100103 upscale = execute # TODO: remove
0 commit comments