From 5e74092b2a7909117445f2d66812a7dba72063f3 Mon Sep 17 00:00:00 2001 From: marduk191 Date: Thu, 5 Mar 2026 07:10:27 -0600 Subject: [PATCH] added source and target framerate option Co-Authored-By: Claude Sonnet 4.6 --- vfi_models/rife/__init__.py | 137 ++++++++++++++++++++++++------------ 1 file changed, 93 insertions(+), 44 deletions(-) diff --git a/vfi_models/rife/__init__.py b/vfi_models/rife/__init__.py index 8c75362..52a5e9c 100644 --- a/vfi_models/rife/__init__.py +++ b/vfi_models/rife/__init__.py @@ -41,7 +41,11 @@ def INPUT_TYPES(s): ), "frames": ("IMAGE", ), "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}), - "multiplier": ("INT", {"default": 2, "min": 1}), + "fps_mode": ("BOOLEAN", {"default": False, + "tooltip": "When enabled, use source_fps and target_fps to determine output frame count " + "instead of multiplier."}), + "multiplier": ("INT", {"default": 2, "min": 1, + "tooltip": "Used when fps_mode is off. Multiplies each input frame pair by this factor."}), "fast_mode": ("BOOLEAN", {"default": True}), "ensemble": ("BOOLEAN", {"default": True}), "scale_factor": ([0.25, 0.5, 1.0, 2.0, 4.0], {"default": 1.0}), @@ -64,7 +68,13 @@ def INPUT_TYPES(s): "Set to 1 for the most conservative behaviour."}), }, "optional": { - "optional_interpolation_states": ("INTERPOLATION_STATES", ) + "optional_interpolation_states": ("INTERPOLATION_STATES", ), + "source_fps": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 300.0, "step": 0.001, + "tooltip": "Frame rate of the input frames. " + "Set both source_fps and target_fps (>0) to use FPS mode instead of multiplier."}), + "target_fps": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 300.0, "step": 0.001, + "tooltip": "Desired output frame rate. " + "Set both source_fps and target_fps (>0) to use FPS mode instead of multiplier."}), } } @@ -77,6 +87,7 @@ def vfi( ckpt_name: typing.AnyStr, frames: torch.Tensor, clear_cache_after_n_frames: int = 10, + fps_mode: bool = False, multiplier: typing.SupportsInt = 2, fast_mode: bool = False, ensemble: bool = False, @@ -85,6 +96,8 @@ def vfi( torch_compile: bool = False, batch_size: int = 1, optional_interpolation_states: InterpolationStateList = None, + source_fps: float = 0.0, + target_fps: float = 0.0, **kwargs ): """ @@ -139,32 +152,66 @@ def vfi( frames = preprocess_frames(frames) - # Normalise multiplier to a per-pair list - n_pairs = len(frames) - 1 - if isinstance(multiplier, int): - multipliers = [int(multiplier)] * n_pairs - else: - multipliers = list(map(int, multiplier)) - multipliers += [2] * (n_pairs - len(multipliers)) - + n_input = len(frames) + n_pairs = n_input - 1 scale_list = [8 / scale_factor, 4 / scale_factor, 2 / scale_factor, 1 / scale_factor] - # Build a flat list of all (pair_idx, timestep) tasks, skipping excluded pairs. - # Each task produces exactly one intermediate frame. + # output_specs: describes every output frame in order. + # ('orig', frame_idx) — copy frames[frame_idx] directly + # ('interp', task_idx) — result of tasks[task_idx] + output_specs: typing.List[typing.Tuple] = [] tasks: typing.List[typing.Tuple[int, float]] = [] - tasks_remaining_per_pair: typing.Dict[int, int] = {} - for pair_idx in range(n_pairs): - if optional_interpolation_states is not None and optional_interpolation_states.is_frame_skipped(pair_idx): - tasks_remaining_per_pair[pair_idx] = 0 - continue - m = multipliers[pair_idx] - n_steps = max(m - 1, 0) - tasks_remaining_per_pair[pair_idx] = n_steps - for step in range(1, m): - tasks.append((pair_idx, step / m)) - - # Storage for intermediate frames, keyed by pair index - results: typing.Dict[int, typing.List[torch.Tensor]] = {i: [] for i in range(n_pairs)} + + use_fps_mode = fps_mode and source_fps > 0 and target_fps > 0 + + if use_fps_mode: + # FPS mode: place output frames at exact target_fps timestamps. + if abs(source_fps - target_fps) < 0.01: + print("Comfy-VFI: source_fps ≈ target_fps, returning frames unchanged.") + return (postprocess_frames(frames.to(torch.float32)),) + + n_output = max(2, round((n_input - 1) * target_fps / source_fps) + 1) + print(f"Comfy-VFI: FPS mode {source_fps} → {target_fps} fps ({n_input} → {n_output} frames)") + + for out_i in range(n_output): + # Map output frame index to a continuous position in input-frame space + input_pos = out_i * (n_input - 1) / (n_output - 1) + pair_idx = int(input_pos) + alpha = input_pos - pair_idx + + if pair_idx >= n_pairs or alpha < 1e-6: + # Exact input frame (or past the last) + output_specs.append(('orig', min(pair_idx, n_input - 1))) + elif optional_interpolation_states is not None and optional_interpolation_states.is_frame_skipped(pair_idx): + output_specs.append(('orig', pair_idx)) + else: + output_specs.append(('interp', len(tasks))) + tasks.append((pair_idx, alpha)) + + else: + # Multiplier mode: insert (multiplier-1) evenly-spaced frames between each pair. + if isinstance(multiplier, int): + multipliers = [int(multiplier)] * n_pairs + else: + multipliers = list(map(int, multiplier)) + multipliers += [2] * (n_pairs - len(multipliers)) + + tasks_remaining_per_pair: typing.Dict[int, int] = {} + for pair_idx in range(n_pairs): + output_specs.append(('orig', pair_idx)) + if optional_interpolation_states is not None and optional_interpolation_states.is_frame_skipped(pair_idx): + tasks_remaining_per_pair[pair_idx] = 0 + continue + m = multipliers[pair_idx] + n_steps = max(m - 1, 0) + tasks_remaining_per_pair[pair_idx] = n_steps + for step in range(1, m): + output_specs.append(('interp', len(tasks))) + tasks.append((pair_idx, step / m)) + output_specs.append(('orig', n_input - 1)) + + # Flat array to hold each interpolated frame result, indexed by task position. + interp_results: typing.List[typing.Optional[torch.Tensor]] = [None] * len(tasks) frames_processed_since_cache_clear = 0 pos = 0 @@ -196,33 +243,35 @@ def vfi( ensemble, ).clamp(0, 1).detach().cpu() - for idx, (pair_idx, _) in enumerate(batch_tasks): - results[pair_idx].append(middle_frames[idx : idx + 1].to(dtype=torch_dtype)) - tasks_remaining_per_pair[pair_idx] -= 1 - # Clear cache after finishing each pair's worth of tasks - if tasks_remaining_per_pair[pair_idx] == 0: - frames_processed_since_cache_clear += 1 - if frames_processed_since_cache_clear >= clear_cache_after_n_frames: - print("Comfy-VFI: Clearing cache...", end=' ') - soft_empty_cache() - gc.collect() - frames_processed_since_cache_clear = 0 - print("Done cache clearing") + for i, (pair_idx, _) in enumerate(batch_tasks): + task_idx = pos + i + interp_results[task_idx] = middle_frames[i : i + 1].to(dtype=torch_dtype) + + if not use_fps_mode: + tasks_remaining_per_pair[pair_idx] -= 1 + if tasks_remaining_per_pair[pair_idx] == 0: + frames_processed_since_cache_clear += 1 + if frames_processed_since_cache_clear >= clear_cache_after_n_frames: + print("Comfy-VFI: Clearing cache...", end=' ') + soft_empty_cache() + gc.collect() + frames_processed_since_cache_clear = 0 + print("Done cache clearing") pos += len(batch_tasks) - # Assemble output: each original frame followed by its interpolated frames + # Assemble output frames in order using output_specs output_frames: typing.List[torch.Tensor] = [] - for pair_idx in range(n_pairs): - output_frames.append(frames[pair_idx : pair_idx + 1].to(dtype=torch_dtype)) - for mid in results[pair_idx]: - output_frames.append(mid) - output_frames.append(frames[-1:].to(dtype=torch_dtype)) + for spec in output_specs: + if spec[0] == 'orig': + output_frames.append(frames[spec[1] : spec[1] + 1].to(dtype=torch_dtype)) + else: + output_frames.append(interp_results[spec[1]]) print("Comfy-VFI: Final clearing cache...", end=' ') soft_empty_cache() print("Done cache clearing") - print(f"Comfy-VFI done! {sum(len(v) for v in results.values()) + len(frames)} frames generated") + print(f"Comfy-VFI done! {len(output_frames)} frames generated") # Always return float32 — numpy and all downstream ComfyUI nodes require it out_tensor = torch.cat(output_frames, dim=0).to(torch.float32)