Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations
import logging, math, torch
from comfy_api.latest import io

logger = logging.getLogger(__name__)


class SwarmVideoResampleFPS(io.ComfyNode):
MIN_FPS: float = 0.01
MAX_FPS: float = 99999.9
STEP_FPS: float = 1.0
DEFAULT_FPS_OUT: float = 24.0
METHOD_LINEAR: str = "linear"
METHOD_NEAREST: str = "nearest"

@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SwarmVideoResampleFPS",
display_name="Swarm Video Resample FPS",
category="SwarmUI/video",
description="Resample a video from fps_in to fps_out while preserving total duration.",
inputs=[
io.Image.Input("images", tooltip="The images to resample."),
io.Float.Input("fps_in", min=cls.MIN_FPS, max=cls.MAX_FPS, step=cls.STEP_FPS, tooltip="Source frame rate."),
io.Float.Input("fps_out", default=cls.DEFAULT_FPS_OUT, min=cls.MIN_FPS, max=cls.MAX_FPS, step=cls.STEP_FPS, tooltip="Target frame rate."),
io.Combo.Input("method", options=[cls.METHOD_LINEAR, cls.METHOD_NEAREST], default=cls.METHOD_LINEAR,
tooltip=(
"linear: each output frame is a linear blend of the two source frames bracketing its timestamp. Equivalent to ffmpeg's framerate filter. Slightly more expensive; avoids the duplicated-frame artifact. See https://ffmpeg.org/ffmpeg-filters.html#framerate\n"
"nearest: each output frame is the source frame closest in time. Equivalent to ffmpeg's fps filter. Cheap; can produce visible judder on pans. See https://ffmpeg.org/ffmpeg-filters.html#fps-1"
),
),
],
outputs=[io.Image.Output("images")],
)

@classmethod
@torch.inference_mode()
def execute(cls, images: torch.Tensor, fps_in: float, fps_out: float, method: str) -> io.NodeOutput:
if fps_in <= 0 or fps_out <= 0:
raise ValueError(f"SwarmVideoResampleFPS: fps_in and fps_out must be positive (got {fps_in}, {fps_out})")

frame_count_in = int(images.shape[0])
if frame_count_in <= 1 or math.isclose(fps_in, fps_out):
return io.NodeOutput(images)

# Compute output frame count and the fractional source-frame position for each output frame: 4 frames @ 2fps -> 4fps yields 8 frames at source positions [0, 0.5, 1.0, ..., 3.5]
frame_count_out = max(1, round(frame_count_in / fps_in * fps_out))
source_positions = torch.arange(frame_count_out, dtype=torch.float64, device=images.device) / fps_out * fps_in

if method == cls.METHOD_NEAREST:
resampled = cls._sample_nearest(images, source_positions)
else:
resampled = cls._sample_linear(images, source_positions)

logger.info(f"SwarmVideoResampleFPS: {frame_count_in} frames @ {fps_in} fps -> {frame_count_out} frames @ {fps_out} fps ({method})")
return io.NodeOutput(resampled)

@classmethod
def _sample_nearest(cls, source_frames: torch.Tensor, source_positions: torch.Tensor) -> torch.Tensor:
"""Pick the closest source frame for each fractional position.

See https://ffmpeg.org/ffmpeg-filters.html#fps-1
"""
nearest_idx = source_positions.round().long()
last_valid_idx = source_frames.shape[0] - 1
nearest_idx = torch.clamp(nearest_idx, 0, last_valid_idx)
return source_frames[nearest_idx].contiguous()

@classmethod
def _sample_linear(cls, source_frames: torch.Tensor, source_positions: torch.Tensor) -> torch.Tensor:
"""Linearly blend the two source frames bracketing each fractional position.

See https://ffmpeg.org/ffmpeg-filters.html#framerate
"""
last_valid_idx = source_frames.shape[0] - 1
lower_idx = torch.clamp(source_positions.floor().long(), 0, last_valid_idx)
upper_idx = torch.clamp(lower_idx + 1, 0, last_valid_idx)
blend_weight = (source_positions - lower_idx.to(torch.float64)).to(source_frames.dtype)
while blend_weight.ndim < source_frames.ndim:
blend_weight = blend_weight.unsqueeze(-1)

lower_frames = source_frames[lower_idx]
upper_frames = source_frames[upper_idx]
return ((1.0 - blend_weight) * lower_frames + blend_weight * upper_frames).contiguous()


NODE_CLASS_MAPPINGS = {
"SwarmVideoResampleFPS": SwarmVideoResampleFPS,
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os, folder_paths, traceback

from . import SwarmBlending, SwarmImages, SwarmInternalUtil, SwarmKSampler, SwarmLoadImageB64, SwarmLoraLoader, SwarmMasks, SwarmSaveImageWS, SwarmTiling, SwarmExtractLora, SwarmUnsampler, SwarmLatents, SwarmInputNodes, SwarmTextHandling, SwarmReference, SwarmMath, SwarmSam2, SwarmAudio
from . import SwarmBlending, SwarmImages, SwarmInternalUtil, SwarmKSampler, SwarmLoadImageB64, SwarmLoraLoader, SwarmMasks, SwarmSaveImageWS, SwarmTiling, SwarmExtractLora, SwarmUnsampler, SwarmLatents, SwarmInputNodes, SwarmTextHandling, SwarmReference, SwarmMath, SwarmSam2, SwarmAudio, SwarmVideo

WEB_DIRECTORY = "./web"

Expand All @@ -23,6 +23,7 @@
| SwarmMath.NODE_CLASS_MAPPINGS
| SwarmSam2.NODE_CLASS_MAPPINGS
| SwarmAudio.NODE_CLASS_MAPPINGS
| SwarmVideo.NODE_CLASS_MAPPINGS
)

try:
Expand Down
12 changes: 11 additions & 1 deletion src/BuiltinExtensions/ComfyUIBackend/WGNodeData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,17 @@ public class WGNodeData(JArray _path, WorkflowGenerator _gen, string _dataType,
public int? Frames = null;

/// <summary>The frames per second of a video, if known and valid.</summary>
public int? FPS = null;
public JToken FPS = null;

/// <summary>Returns the FPS as an int, or null if it is a node-ref or unset.</summary>
public int? GetRawFPS()
{
if (FPS is JValue v && v.Type == JTokenType.Integer)
{
return v.Value<int>();
}
return null;
}

/// <summary>If this is a video data object, and audio is separate but tracked, this is the audio associated.</summary>
public WGNodeData AttachedAudio = null;
Expand Down
4 changes: 3 additions & 1 deletion src/BuiltinExtensions/ComfyUIBackend/WorkflowGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ public WGNodeData LoadImage(ImageFile img, string param, bool resize, string nod
else
{
WGNodeData attachedAudio = null;
JToken fpsRef = null;
if (img.Type.MetaType == MediaMetaType.Video)
{
result = CreateNode("SwarmLoadVideoB64", new JObject()
Expand All @@ -473,6 +474,7 @@ public WGNodeData LoadImage(ImageFile img, string param, bool resize, string nod
});
result = splitNode;
attachedAudio = new([splitNode, 1], this, WGNodeData.DT_AUDIO, CurrentCompat());
fpsRef = NodePath(splitNode, 2);
}
else
{
Expand All @@ -496,7 +498,7 @@ public WGNodeData LoadImage(ImageFile img, string param, bool resize, string nod
["crop"] = "disabled"
}, nodeId);
}
return new([result, 0], this, WGNodeData.DT_VIDEO, CurrentCompat()) { AttachedAudio = attachedAudio, Width = imgWidth, Height = imgHeight };
return new([result, 0], this, WGNodeData.DT_VIDEO, CurrentCompat()) { AttachedAudio = attachedAudio, Width = imgWidth, Height = imgHeight, FPS = fpsRef };
}
}
else
Expand Down
15 changes: 14 additions & 1 deletion src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,19 @@ bool getBestFor(string phrase)
}
if (preprocessor.ToLowerFast() != "none")
{
if (imageNodeActual.DataType == WGNodeData.DT_VIDEO && imageNodeActual.FPS is not null)
{
int fps = g.Text2VideoFPS();
string resampleNode = g.CreateNode("SwarmVideoResampleFPS", new JObject()
{
["images"] = imageNodeActual.Path,
["fps_in"] = imageNodeActual.FPS,
["fps_out"] = fps,
["method"] = "linear"
});
imageNodeActual = imageNodeActual.WithPath([resampleNode, 0]);
imageNodeActual.FPS = fps;
}
JArray preprocActual = g.CreatePreprocessor(preprocessor, imageNodeActual);
g.NodeHelpers["controlnet_preprocessor"] = $"{preprocActual[0]}";
imageNodeActual = imageNodeActual.WithPath(preprocActual);
Expand Down Expand Up @@ -1867,7 +1880,7 @@ void RunSegmentationProcessing(WorkflowGenerator g, bool isBeforeRefiner)
}
JArray newInterp = g.DoInterpolation(g.CurrentMedia.Path, method, mult);
g.CurrentMedia = g.CurrentMedia.WithPath(newInterp);
int fps = g.CurrentMedia.FPS ?? g.Text2VideoFPS();
int fps = g.CurrentMedia.GetRawFPS() ?? g.Text2VideoFPS();
fps *= mult;
g.CurrentMedia.FPS = fps;
g.T2VFPSOverride = fps;
Expand Down
Loading