Skip to content

Commit c33d26c

Browse files
authored
fix: Proper memory estimation for frame interpolation when not using dynamic VRAM (#13698)
1 parent f3ea976 commit c33d26c

3 files changed

Lines changed: 10 additions & 7 deletions

File tree

comfy_extras/frame_interpolation_models/film_net.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,9 @@ def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels
199199
def get_dtype(self):
200200
return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
201201

202+
def memory_used_forward(self, shape, dtype):
203+
return 1700 * shape[1] * shape[2] * dtype.itemsize
204+
202205
def _build_warp_grids(self, H, W, device):
203206
"""Pre-compute warp grids for all pyramid levels."""
204207
if (H, W) in self._warp_grids:

comfy_extras/frame_interpolation_models/ifnet.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,9 @@ def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtyp
7474
def get_dtype(self):
7575
return self.encode.cnn0.weight.dtype
7676

77+
def memory_used_forward(self, shape, dtype):
78+
return 300 * shape[1] * shape[2] * dtype.itemsize
79+
7780
def _build_warp_grids(self, H, W, device):
7881
if (H, W) in self._warp_grids:
7982
return

comfy_extras/nodes_frame_interpolation.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def execute(cls, model_name) -> io.NodeOutput:
3737
model = cls._detect_and_load(sd)
3838
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
3939
model.eval().to(dtype)
40-
patcher = comfy.model_patcher.ModelPatcher(
40+
patcher = comfy.model_patcher.CoreModelPatcher(
4141
model,
4242
load_device=model_management.get_torch_device(),
4343
offload_device=model_management.unet_offload_device(),
@@ -98,16 +98,13 @@ def execute(cls, interp_model, images, multiplier) -> io.NodeOutput:
9898
if num_frames < 2 or multiplier < 2:
9999
return io.NodeOutput(images)
100100

101-
model_management.load_model_gpu(interp_model)
102101
device = interp_model.load_device
103102
dtype = interp_model.model_dtype()
104103
inference_model = interp_model.model
105-
106-
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
107-
H, W = images.shape[1], images.shape[2]
108-
activation_mem = H * W * 3 * images.element_size() * 20
109-
model_management.free_memory(activation_mem, device)
104+
activation_mem = inference_model.memory_used_forward(images.shape, dtype)
105+
model_management.load_models_gpu([interp_model], memory_required=activation_mem)
110106
align = getattr(inference_model, "pad_align", 1)
107+
H, W = images.shape[1], images.shape[2]
111108

112109
# Prepare a single padded frame on device for determining output dimensions
113110
def prepare_frame(idx):

0 commit comments

Comments
 (0)