Frame Interpolate can fail with CUDA input / CPU weight error #13583

@nomadoor

Description

Expected Behavior

Using the Frame Interpolate node should correctly interpolate frames (e.g. doubling FPS) without errors.

Actual Behavior

When using the Frame Interpolate node added in PR #13258, interpolation can fail with the following error:

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) should be the same

It looks like the input tensor is on CUDA while the interpolation model has been offloaded to CPU (possibly by free_memory()).

I am not sure this is the correct fix, but the error goes away for me if the interpolation model is explicitly kept loaded and moved to the correct device/dtype before inference:

-        model_management.free_memory(activation_mem, device)
+        model_management.free_memory(activation_mem, device, keep_loaded=[interp_model])
+        inference_model.to(device=device, dtype=dtype)

So this may be a device placement issue in the Frame Interpolate node.
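For reference, the failure mode can be reproduced outside ComfyUI with a minimal sketch (plain PyTorch, not the actual FILM or ComfyUI code): a convolution whose weights are in one dtype/device rejects an input in another, and aligning the two sides resolves it.

```python
import torch
import torch.nn as nn

# Minimal sketch of the failure: a module whose weights are in one
# dtype/device is called with an input in another, which F.conv2d rejects.
model = nn.Conv2d(3, 8, 3, padding=1)               # float32 weights
x = torch.randn(1, 3, 16, 16, dtype=torch.float16)  # mismatched input

try:
    model(x)
    raised = False
except RuntimeError:
    raised = True  # "Input type ... and weight type ... should be the same"

# One defensive fix is to align the two sides before inference. Here the
# input is cast to the weights' dtype; the diff above does the reverse,
# moving the model onto the input tensor's device/dtype.
y = model(x.to(next(model.parameters()).dtype))
print(raised, tuple(y.shape))
```

The same mismatch happens across devices (torch.cuda.HalfTensor input vs. torch.HalfTensor weights, as in the traceback below); this sketch uses a dtype mismatch only so it reproduces on CPU.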

Steps to Reproduce

(Screenshot of the workflow attached.)

Workflow file: FILM.json

Debug Logs

Requested to load FILMNet
loaded completely; 8128.68 MB usable, 65.68 MB loaded, full load: True
Frame interpolation:   0%|                                                                     | 0/120 [00:00<?, ?it/s]
!!! Exception during processing !!! Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) should be the same
Traceback (most recent call last):
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\execution.py", line 534, in execute
    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\execution.py", line 334, in get_output_data
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\execution.py", line 308, in _async_map_node_over_list
    await process_inputs(input_dict, i)
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\execution.py", line 296, in process_inputs
    result = f(**inputs)
             ^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\comfy_api\internal\__init__.py", line 149, in wrapped_func
    return method(locked_class, **inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\comfy_api\latest\_io.py", line 1826, in EXECUTE_NORMALIZED
    to_return = cls.execute(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\comfy_extras\nodes_frame_interpolation.py", line 154, in execute
    feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single)
                                                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\comfy_extras\frame_interpolation_models\film_net.py", line 221, in extract_features
    feature_pyramid = self.extract(image_pyramid)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\comfy_extras\frame_interpolation_models\film_net.py", line 99, in forward
    sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\comfy_extras\frame_interpolation_models\film_net.py", line 85, in forward
    head = layer(head)
           ^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\container.py", line 240, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\comfy_extras\frame_interpolation_models\film_net.py", line 24, in forward
    x = self.conv(x)
        ^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\ComfyUI\comfy\ops.py", line 428, in forward
    return super().forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\conv.py", line 554, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\AI\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\conv.py", line 549, in _conv_forward
    return F.conv2d(
           ^^^^^^^^^
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) should be the same

Other

ComfyUI 0.19.3
ComfyUI_frontend v1.42.15

System Info

OS: win32
Python Version: 3.12.10 (tags/v3.12.10:0cc8128, Apr 8 2025, 12:21:36) [MSC v.1943 64 bit (AMD64)]
Embedded Python: true
Pytorch Version: 2.7.1+cu128
Arguments: ComfyUI\main.py --disable-auto-launch --reserve-vram 2.0 --disable-smart-memory
RAM Total: 95.87 GB
RAM Free: 40.42 GB
Templates Version: 0.9.62

Devices

  • cuda:0 NVIDIA GeForce RTX 4070 Ti : cudaMallocAsync (cuda)
    VRAM Total: 11.99 GB
    VRAM Free: 10.74 GB
    Torch VRAM Total: 32 MB
    Torch VRAM Free: 23.88 MB

Labels

Potential Bug (User is reporting a bug. This should be tested.)
