@@ -37,7 +37,7 @@ def execute(cls, model_name) -> io.NodeOutput:
3737 model = cls._detect_and_load(sd)
3838 dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
3939 model.eval().to(dtype)
40- patcher = comfy.model_patcher.ModelPatcher(
40+ patcher = comfy.model_patcher.CoreModelPatcher(
4141 model,
4242 load_device=model_management.get_torch_device(),
4343 offload_device=model_management.unet_offload_device(),
@@ -98,16 +98,13 @@ def execute(cls, interp_model, images, multiplier) -> io.NodeOutput:
9898 if num_frames < 2 or multiplier < 2:
9999 return io.NodeOutput(images)
100100
101- model_management.load_model_gpu(interp_model)
102101 device = interp_model.load_device
103102 dtype = interp_model.model_dtype()
104103 inference_model = interp_model.model
105-
106- # Free VRAM for inference activations (model weights + ~20x a single frame's worth)
107- H, W = images.shape[1], images.shape[2]
108- activation_mem = H * W * 3 * images.element_size() * 20
109- model_management.free_memory(activation_mem, device)
104+ activation_mem = inference_model.memory_used_forward(images.shape, dtype)
105+ model_management.load_models_gpu([interp_model], memory_required=activation_mem)
110106 align = getattr(inference_model, "pad_align", 1)
107+ H, W = images.shape[1], images.shape[2]
111108
112109 # Prepare a single padded frame on device for determining output dimensions
113110 def prepare_frame(idx):
0 commit comments