44
55import comfy
66import comfy .model_management
7+ import comfy .model_patcher
78import comfy .samplers
89import comfy .utils
10+ import folder_paths
911import node_helpers
1012import nodes
1113from comfy .utils import model_trange as trange
1214from comfy_api .latest import ComfyExtension , io
15+ from torchvision .models .optical_flow import raft_large
1316from typing_extensions import override
1417
15- from comfy_extras .void_noise_warp import get_noise_from_video
18+
19+ from comfy_extras .void_noise_warp import RaftOpticalFlow , get_noise_from_video
20+
21+ OpticalFlow = io .Custom ("OPTICAL_FLOW" )
1622
1723TEMPORAL_COMPRESSION = 4
1824PATCH_SIZE_T = 2
@@ -38,6 +44,67 @@ def _valid_void_length(length: int) -> int:
3844 return (target_latent_t - 1 ) * TEMPORAL_COMPRESSION + 1
3945
4046
47+ class OpticalFlowLoader (io .ComfyNode ):
48+ """Load an optical flow model from ``models/optical_flow/``.
49+
50+ Only torchvision's RAFT-large format is recognized today (the model used
51+ by VOIDWarpedNoise). The checkpoint must be placed under
52+ ``models/optical_flow/`` — ComfyUI never downloads optical-flow weights
53+ at runtime.
54+ """
55+
56+ @classmethod
57+ def define_schema (cls ):
58+ return io .Schema (
59+ node_id = "OpticalFlowLoader" ,
60+ display_name = "Load Optical Flow Model" ,
61+ category = "loaders" ,
62+ inputs = [
63+ io .Combo .Input (
64+ "model_name" ,
65+ options = folder_paths .get_filename_list ("optical_flow" ),
66+ tooltip = (
67+ "Optical flow model to load. Files must be placed in the "
68+ "'optical_flow' folder. Today only torchvision's "
69+ "raft_large.pth is supported."
70+ ),
71+ ),
72+ ],
73+ outputs = [
74+ OpticalFlow .Output (),
75+ ],
76+ )
77+
78+ @classmethod
79+ def execute (cls , model_name ) -> io .NodeOutput :
80+
81+ model_path = folder_paths .get_full_path_or_raise ("optical_flow" , model_name )
82+ sd = comfy .utils .load_torch_file (model_path , safe_load = True )
83+
84+ has_raft_keys = (
85+ any (k .startswith ("feature_encoder." ) for k in sd )
86+ and any (k .startswith ("context_encoder." ) for k in sd )
87+ and any (k .startswith ("update_block." ) for k in sd )
88+ )
89+ if not has_raft_keys :
90+ raise ValueError (
91+ "Unrecognized optical flow model format: expected a torchvision "
92+ "RAFT-large state dict with 'feature_encoder.', 'context_encoder.' "
93+ "and 'update_block.' prefixes."
94+ )
95+
96+ model = raft_large (weights = None , progress = False )
97+ model .load_state_dict (sd )
98+ model .eval ().to (torch .float32 )
99+
100+ patcher = comfy .model_patcher .ModelPatcher (
101+ model ,
102+ load_device = comfy .model_management .get_torch_device (),
103+ offload_device = comfy .model_management .unet_offload_device (),
104+ )
105+ return io .NodeOutput (patcher )
106+
107+
41108class VOIDQuadmaskPreprocess (io .ComfyNode ):
42109 """Preprocess a quadmask video for VOID inpainting.
43110
@@ -222,6 +289,10 @@ def define_schema(cls):
222289 node_id = "VOIDWarpedNoise" ,
223290 category = "latent/video" ,
224291 inputs = [
292+ OpticalFlow .Input (
293+ "optical_flow" ,
294+ tooltip = "Optical flow model from OpticalFlowLoader (RAFT-large)." ,
295+ ),
225296 io .Image .Input ("video" , tooltip = "Pass 1 output video frames [T, H, W, 3]" ),
226297 io .Int .Input ("width" , default = 672 , min = 16 , max = nodes .MAX_RESOLUTION , step = 8 ),
227298 io .Int .Input ("height" , default = 384 , min = 16 , max = nodes .MAX_RESOLUTION , step = 8 ),
@@ -236,7 +307,7 @@ def define_schema(cls):
236307 )
237308
238309 @classmethod
239- def execute (cls , video , width , height , length , batch_size ) -> io .NodeOutput :
310+ def execute (cls , optical_flow , video , width , height , length , batch_size ) -> io .NodeOutput :
240311
241312 adjusted_length = _valid_void_length (length )
242313 if adjusted_length != length :
@@ -257,6 +328,9 @@ def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput:
257328 # rest of the ComfyUI pipeline.
258329 device = comfy .model_management .get_torch_device ()
259330
331+ comfy .model_management .load_model_gpu (optical_flow )
332+ raft = RaftOpticalFlow (optical_flow .model , device = device )
333+
260334 vid = video [:length ].to (device )
261335 vid = comfy .utils .common_upscale (
262336 vid .movedim (- 1 , 1 ), width , height , "bilinear" , "center"
@@ -269,6 +343,7 @@ def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput:
269343
270344 warped = get_noise_from_video (
271345 vid_uint8 ,
346+ raft ,
272347 noise_channels = 16 ,
273348 resize_frames = FRAME ,
274349 resize_flow = FLOW ,
@@ -395,6 +470,7 @@ class VOIDExtension(ComfyExtension):
395470 @override
396471 async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
397472 return [
473+ OpticalFlowLoader ,
398474 VOIDQuadmaskPreprocess ,
399475 VOIDInpaintConditioning ,
400476 VOIDWarpedNoise ,
0 commit comments