Skip to content

Commit 0fca6d7

Browse files
author
Talmaj Marinc
committed
Add Optical Flow Loader.
1 parent 5c11f5d commit 0fca6d7

3 files changed

Lines changed: 95 additions & 20 deletions

File tree

comfy_extras/nodes_void.py

Lines changed: 78 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,15 +4,21 @@
44

55
import comfy
66
import comfy.model_management
7+
import comfy.model_patcher
78
import comfy.samplers
89
import comfy.utils
10+
import folder_paths
911
import node_helpers
1012
import nodes
1113
from comfy.utils import model_trange as trange
1214
from comfy_api.latest import ComfyExtension, io
15+
from torchvision.models.optical_flow import raft_large
1316
from typing_extensions import override
1417

15-
from comfy_extras.void_noise_warp import get_noise_from_video
18+
19+
from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video
20+
21+
OpticalFlow = io.Custom("OPTICAL_FLOW")
1622

1723
TEMPORAL_COMPRESSION = 4
1824
PATCH_SIZE_T = 2
@@ -38,6 +44,67 @@ def _valid_void_length(length: int) -> int:
3844
return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1
3945

4046

47+
class OpticalFlowLoader(io.ComfyNode):
    """Node that loads an optical-flow checkpoint from ``models/optical_flow/``.

    The only layout accepted today is a torchvision RAFT-large state dict
    (the model consumed by VOIDWarpedNoise). The checkpoint has to be present
    under ``models/optical_flow/``; this node builds the network with
    ``weights=None`` and therefore never fetches weights itself.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: one checkpoint combo input, one OPTICAL_FLOW output."""
        model_choice = io.Combo.Input(
            "model_name",
            options=folder_paths.get_filename_list("optical_flow"),
            tooltip=(
                "Optical flow model to load. Files must be placed in the "
                "'optical_flow' folder. Today only torchvision's "
                "raft_large.pth is supported."
            ),
        )
        return io.Schema(
            node_id="OpticalFlowLoader",
            display_name="Load Optical Flow Model",
            category="loaders",
            inputs=[model_choice],
            outputs=[OpticalFlow.Output()],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        """Load ``model_name`` and return it wrapped in a ``ModelPatcher``.

        Args:
            model_name: File name selected from the ``optical_flow`` folder.

        Returns:
            A node output holding a ``ModelPatcher`` around the RAFT module.

        Raises:
            ValueError: If the state dict is missing any of the
                ``feature_encoder.``, ``context_encoder.`` or
                ``update_block.`` key prefixes.
        """
        checkpoint_path = folder_paths.get_full_path_or_raise("optical_flow", model_name)
        state_dict = comfy.utils.load_torch_file(checkpoint_path, safe_load=True)

        # Cheap format sniff: every one of these prefixes must occur at least
        # once among the checkpoint keys before we try to load it.
        required_prefixes = ("feature_encoder.", "context_encoder.", "update_block.")
        looks_like_raft = all(
            any(key.startswith(prefix) for key in state_dict)
            for prefix in required_prefixes
        )
        if not looks_like_raft:
            raise ValueError(
                "Unrecognized optical flow model format: expected a torchvision "
                "RAFT-large state dict with 'feature_encoder.', 'context_encoder.' "
                "and 'update_block.' prefixes."
            )

        # Build the architecture without pretrained weights, then fill it from
        # the local checkpoint.
        raft = raft_large(weights=None, progress=False)
        raft.load_state_dict(state_dict)
        raft.eval()
        raft.to(torch.float32)

        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()
        wrapped = comfy.model_patcher.ModelPatcher(
            raft,
            load_device=load_device,
            offload_device=offload_device,
        )
        return io.NodeOutput(wrapped)
106+
107+
41108
class VOIDQuadmaskPreprocess(io.ComfyNode):
42109
"""Preprocess a quadmask video for VOID inpainting.
43110
@@ -222,6 +289,10 @@ def define_schema(cls):
222289
node_id="VOIDWarpedNoise",
223290
category="latent/video",
224291
inputs=[
292+
OpticalFlow.Input(
293+
"optical_flow",
294+
tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).",
295+
),
225296
io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"),
226297
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
227298
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
@@ -236,7 +307,7 @@ def define_schema(cls):
236307
)
237308

238309
@classmethod
239-
def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput:
310+
def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput:
240311

241312
adjusted_length = _valid_void_length(length)
242313
if adjusted_length != length:
@@ -257,6 +328,9 @@ def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput:
257328
# rest of the ComfyUI pipeline.
258329
device = comfy.model_management.get_torch_device()
259330

331+
comfy.model_management.load_model_gpu(optical_flow)
332+
raft = RaftOpticalFlow(optical_flow.model, device=device)
333+
260334
vid = video[:length].to(device)
261335
vid = comfy.utils.common_upscale(
262336
vid.movedim(-1, 1), width, height, "bilinear", "center"
@@ -269,6 +343,7 @@ def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput:
269343

270344
warped = get_noise_from_video(
271345
vid_uint8,
346+
raft,
272347
noise_channels=16,
273348
resize_frames=FRAME,
274349
resize_flow=FLOW,
@@ -395,6 +470,7 @@ class VOIDExtension(ComfyExtension):
395470
@override
396471
async def get_node_list(self) -> list[type[io.ComfyNode]]:
397472
return [
473+
OpticalFlowLoader,
398474
VOIDQuadmaskPreprocess,
399475
VOIDInpaintConditioning,
400476
VOIDWarpedNoise,

comfy_extras/void_noise_warp.py

Lines changed: 15 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,10 @@
99
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
1010
uses (torch THWC uint8 input, no background removal, no visualization, no disk
1111
I/O, default warp/noise params) have been inlined. External ``rp`` utilities
12-
have been replaced with equivalents from torch.nn.functional / einops /
13-
torchvision.
12+
have been replaced with equivalents from torch.nn.functional / einops. The
13+
RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in
14+
``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this
15+
module never downloads weights at runtime.
1416
"""
1517

1618
import logging
@@ -19,7 +21,6 @@
1921
import torch
2022
import torch.nn.functional as F
2123
from einops import rearrange
22-
from torchvision.models.optical_flow import raft_large
2324

2425
import comfy.model_management
2526

@@ -345,14 +346,20 @@ def __call__(self, dx, dy):
345346
# ---------------------------------------------------------------------------
346347

347348
class RaftOpticalFlow:
348-
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
349+
"""RAFT-large wrapper around a pre-loaded torchvision model.
349350
350-
def __init__(self, device=None):
351+
``model`` must be the ``torchvision.models.optical_flow.raft_large`` module
352+
with its weights already populated; this class is load-agnostic so the
353+
caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in
354+
``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow.
355+
"""
356+
357+
def __init__(self, model, device=None):
351358
if device is None:
352359
device = comfy.model_management.get_torch_device()
353360
device = torch.device(device) if not isinstance(device, torch.device) else device
354361

355-
model = raft_large(weights="DEFAULT", progress=False).to(device)
362+
model = model.to(device)
356363
model.eval()
357364
self.device = device
358365
self.model = model
@@ -384,22 +391,13 @@ def __call__(self, from_image, to_image):
384391
return flow
385392

386393

387-
_raft_cache: dict = {}
388-
389-
390-
def _get_raft_model(device):
391-
key = str(device)
392-
if key not in _raft_cache:
393-
_raft_cache[key] = RaftOpticalFlow(device=device)
394-
return _raft_cache[key]
395-
396-
397394
# ---------------------------------------------------------------------------
398395
# Narrow entry point used by VOIDWarpedNoise
399396
# ---------------------------------------------------------------------------
400397

401398
def get_noise_from_video(
402399
video_frames: torch.Tensor,
400+
raft: RaftOpticalFlow,
403401
*,
404402
noise_channels: int = 16,
405403
resize_frames: float = 0.5,
@@ -411,6 +409,7 @@ def get_noise_from_video(
411409
412410
Args:
413411
video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
412+
raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``).
414413
noise_channels: Channels in the output noise.
415414
resize_frames: Pre-RAFT frame scale factor.
416415
resize_flow: Post-flow up-scale factor applied to the optical flow;
@@ -465,8 +464,6 @@ def get_noise_from_video(
465464
internal_h, internal_w, downscale_factor,
466465
)
467466

468-
raft = _get_raft_model(device)
469-
470467
with torch.no_grad():
471468
warper = NoiseWarper(
472469
c=noise_channels, h=internal_h, w=internal_w, device=device,

folder_paths.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,8 @@
5454

5555
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
5656

57+
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
58+
5759
output_directory = os.path.join(base_path, "output")
5860
temp_directory = os.path.join(base_path, "temp")
5961
input_directory = os.path.join(base_path, "input")

0 commit comments

Comments
 (0)