Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 93 additions & 44 deletions vfi_models/rife/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def INPUT_TYPES(s):
),
"frames": ("IMAGE", ),
"clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
"multiplier": ("INT", {"default": 2, "min": 1}),
"fps_mode": ("BOOLEAN", {"default": False,
"tooltip": "When enabled, use source_fps and target_fps to determine output frame count "
"instead of multiplier."}),
"multiplier": ("INT", {"default": 2, "min": 1,
"tooltip": "Used when fps_mode is off. Multiplies each input frame pair by this factor."}),
"fast_mode": ("BOOLEAN", {"default": True}),
"ensemble": ("BOOLEAN", {"default": True}),
"scale_factor": ([0.25, 0.5, 1.0, 2.0, 4.0], {"default": 1.0}),
Expand All @@ -64,7 +68,13 @@ def INPUT_TYPES(s):
"Set to 1 for the most conservative behaviour."}),
},
"optional": {
"optional_interpolation_states": ("INTERPOLATION_STATES", )
"optional_interpolation_states": ("INTERPOLATION_STATES", ),
"source_fps": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 300.0, "step": 0.001,
"tooltip": "Frame rate of the input frames. "
"Set both source_fps and target_fps (>0) to use FPS mode instead of multiplier."}),
"target_fps": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 300.0, "step": 0.001,
"tooltip": "Desired output frame rate. "
"Set both source_fps and target_fps (>0) to use FPS mode instead of multiplier."}),
}
}

Expand All @@ -77,6 +87,7 @@ def vfi(
ckpt_name: typing.AnyStr,
frames: torch.Tensor,
clear_cache_after_n_frames: int = 10,
fps_mode: bool = False,
multiplier: typing.SupportsInt = 2,
fast_mode: bool = False,
ensemble: bool = False,
Expand All @@ -85,6 +96,8 @@ def vfi(
torch_compile: bool = False,
batch_size: int = 1,
optional_interpolation_states: InterpolationStateList = None,
source_fps: float = 0.0,
target_fps: float = 0.0,
**kwargs
):
"""
Expand Down Expand Up @@ -139,32 +152,66 @@ def vfi(

frames = preprocess_frames(frames)

# Normalise multiplier to a per-pair list
n_pairs = len(frames) - 1
if isinstance(multiplier, int):
multipliers = [int(multiplier)] * n_pairs
else:
multipliers = list(map(int, multiplier))
multipliers += [2] * (n_pairs - len(multipliers))

n_input = len(frames)
n_pairs = n_input - 1
scale_list = [8 / scale_factor, 4 / scale_factor, 2 / scale_factor, 1 / scale_factor]

# Build a flat list of all (pair_idx, timestep) tasks, skipping excluded pairs.
# Each task produces exactly one intermediate frame.
# output_specs: describes every output frame in order.
# ('orig', frame_idx) — copy frames[frame_idx] directly
# ('interp', task_idx) — result of tasks[task_idx]
output_specs: typing.List[typing.Tuple] = []
tasks: typing.List[typing.Tuple[int, float]] = []
tasks_remaining_per_pair: typing.Dict[int, int] = {}
for pair_idx in range(n_pairs):
if optional_interpolation_states is not None and optional_interpolation_states.is_frame_skipped(pair_idx):
tasks_remaining_per_pair[pair_idx] = 0
continue
m = multipliers[pair_idx]
n_steps = max(m - 1, 0)
tasks_remaining_per_pair[pair_idx] = n_steps
for step in range(1, m):
tasks.append((pair_idx, step / m))

# Storage for intermediate frames, keyed by pair index
results: typing.Dict[int, typing.List[torch.Tensor]] = {i: [] for i in range(n_pairs)}

use_fps_mode = fps_mode and source_fps > 0 and target_fps > 0

if use_fps_mode:
# FPS mode: place output frames at exact target_fps timestamps.
if abs(source_fps - target_fps) < 0.01:
print("Comfy-VFI: source_fps ≈ target_fps, returning frames unchanged.")
return (postprocess_frames(frames.to(torch.float32)),)

n_output = max(2, round((n_input - 1) * target_fps / source_fps) + 1)
print(f"Comfy-VFI: FPS mode {source_fps} → {target_fps} fps ({n_input} → {n_output} frames)")

for out_i in range(n_output):
# Map output frame index to a continuous position in input-frame space
input_pos = out_i * (n_input - 1) / (n_output - 1)
pair_idx = int(input_pos)
alpha = input_pos - pair_idx

if pair_idx >= n_pairs or alpha < 1e-6:
# Exact input frame (or past the last)
output_specs.append(('orig', min(pair_idx, n_input - 1)))
elif optional_interpolation_states is not None and optional_interpolation_states.is_frame_skipped(pair_idx):
output_specs.append(('orig', pair_idx))
else:
output_specs.append(('interp', len(tasks)))
tasks.append((pair_idx, alpha))

else:
# Multiplier mode: insert (multiplier-1) evenly-spaced frames between each pair.
if isinstance(multiplier, int):
multipliers = [int(multiplier)] * n_pairs
else:
multipliers = list(map(int, multiplier))
multipliers += [2] * (n_pairs - len(multipliers))

tasks_remaining_per_pair: typing.Dict[int, int] = {}
for pair_idx in range(n_pairs):
output_specs.append(('orig', pair_idx))
if optional_interpolation_states is not None and optional_interpolation_states.is_frame_skipped(pair_idx):
tasks_remaining_per_pair[pair_idx] = 0
continue
m = multipliers[pair_idx]
n_steps = max(m - 1, 0)
tasks_remaining_per_pair[pair_idx] = n_steps
for step in range(1, m):
output_specs.append(('interp', len(tasks)))
tasks.append((pair_idx, step / m))
output_specs.append(('orig', n_input - 1))

# Flat array to hold each interpolated frame result, indexed by task position.
interp_results: typing.List[typing.Optional[torch.Tensor]] = [None] * len(tasks)

frames_processed_since_cache_clear = 0
pos = 0
Expand Down Expand Up @@ -196,33 +243,35 @@ def vfi(
ensemble,
).clamp(0, 1).detach().cpu()

for idx, (pair_idx, _) in enumerate(batch_tasks):
results[pair_idx].append(middle_frames[idx : idx + 1].to(dtype=torch_dtype))
tasks_remaining_per_pair[pair_idx] -= 1
# Clear cache after finishing each pair's worth of tasks
if tasks_remaining_per_pair[pair_idx] == 0:
frames_processed_since_cache_clear += 1
if frames_processed_since_cache_clear >= clear_cache_after_n_frames:
print("Comfy-VFI: Clearing cache...", end=' ')
soft_empty_cache()
gc.collect()
frames_processed_since_cache_clear = 0
print("Done cache clearing")
for i, (pair_idx, _) in enumerate(batch_tasks):
task_idx = pos + i
interp_results[task_idx] = middle_frames[i : i + 1].to(dtype=torch_dtype)

if not use_fps_mode:
tasks_remaining_per_pair[pair_idx] -= 1
if tasks_remaining_per_pair[pair_idx] == 0:
frames_processed_since_cache_clear += 1
if frames_processed_since_cache_clear >= clear_cache_after_n_frames:
print("Comfy-VFI: Clearing cache...", end=' ')
soft_empty_cache()
gc.collect()
frames_processed_since_cache_clear = 0
print("Done cache clearing")

pos += len(batch_tasks)

# Assemble output: each original frame followed by its interpolated frames
# Assemble output frames in order using output_specs
output_frames: typing.List[torch.Tensor] = []
for pair_idx in range(n_pairs):
output_frames.append(frames[pair_idx : pair_idx + 1].to(dtype=torch_dtype))
for mid in results[pair_idx]:
output_frames.append(mid)
output_frames.append(frames[-1:].to(dtype=torch_dtype))
for spec in output_specs:
if spec[0] == 'orig':
output_frames.append(frames[spec[1] : spec[1] + 1].to(dtype=torch_dtype))
else:
output_frames.append(interp_results[spec[1]])

print("Comfy-VFI: Final clearing cache...", end=' ')
soft_empty_cache()
print("Done cache clearing")
print(f"Comfy-VFI done! {sum(len(v) for v in results.values()) + len(frames)} frames generated")
print(f"Comfy-VFI done! {len(output_frames)} frames generated")

# Always return float32 — numpy and all downstream ComfyUI nodes require it
out_tensor = torch.cat(output_frames, dim=0).to(torch.float32)
Expand Down