Skip to content

Commit 8718623

Browse files
Adds SwarmVideoResampleFPS; resamples controlnet preview videos (#1387)
* Adds SwarmVideoResampleFPS; resamples controlnet preview videos * PR feedback * Collapse _source_positions() * Splits MAX_FPS_IN and MAX_FPS_OUT * Formatting * Code review feedback * Lower MIN_FPS * PR feedback: FPS to JToken; breaking change * fixes --------- Co-authored-by: Alex "mcmonkey" Goodwin <git_commits@alexgoodwin.dev>
1 parent b0b9610 commit 8718623

5 files changed

Lines changed: 120 additions & 4 deletions

File tree

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from __future__ import annotations
2+
import logging, math, torch
3+
from comfy_api.latest import io
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
class SwarmVideoResampleFPS(io.ComfyNode):
9+
MIN_FPS: float = 0.01
10+
MAX_FPS: float = 99999.9
11+
STEP_FPS: float = 1.0
12+
DEFAULT_FPS_OUT: float = 24.0
13+
METHOD_LINEAR: str = "linear"
14+
METHOD_NEAREST: str = "nearest"
15+
16+
@classmethod
17+
def define_schema(cls) -> io.Schema:
18+
return io.Schema(
19+
node_id="SwarmVideoResampleFPS",
20+
display_name="Swarm Video Resample FPS",
21+
category="SwarmUI/video",
22+
description="Resample a video from fps_in to fps_out while preserving total duration.",
23+
inputs=[
24+
io.Image.Input("images", tooltip="The images to resample."),
25+
io.Float.Input("fps_in", min=cls.MIN_FPS, max=cls.MAX_FPS, step=cls.STEP_FPS, tooltip="Source frame rate."),
26+
io.Float.Input("fps_out", default=cls.DEFAULT_FPS_OUT, min=cls.MIN_FPS, max=cls.MAX_FPS, step=cls.STEP_FPS, tooltip="Target frame rate."),
27+
io.Combo.Input("method", options=[cls.METHOD_LINEAR, cls.METHOD_NEAREST], default=cls.METHOD_LINEAR,
28+
tooltip=(
29+
"linear: each output frame is a linear blend of the two source frames bracketing its timestamp. Equivalent to ffmpeg's framerate filter. Slightly more expensive; avoids the duplicated-frame artifact. See https://ffmpeg.org/ffmpeg-filters.html#framerate\n"
30+
"nearest: each output frame is the source frame closest in time. Equivalent to ffmpeg's fps filter. Cheap; can produce visible judder on pans. See https://ffmpeg.org/ffmpeg-filters.html#fps-1"
31+
),
32+
),
33+
],
34+
outputs=[io.Image.Output("images")],
35+
)
36+
37+
@classmethod
38+
@torch.inference_mode()
39+
def execute(cls, images: torch.Tensor, fps_in: float, fps_out: float, method: str) -> io.NodeOutput:
40+
if fps_in <= 0 or fps_out <= 0:
41+
raise ValueError(f"SwarmVideoResampleFPS: fps_in and fps_out must be positive (got {fps_in}, {fps_out})")
42+
43+
frame_count_in = int(images.shape[0])
44+
if frame_count_in <= 1 or math.isclose(fps_in, fps_out):
45+
return io.NodeOutput(images)
46+
47+
# Compute output frame count and the fractional source-frame position for each output frame: 4 frames @ 2fps -> 4fps yields 8 frames at source positions [0, 0.5, 1.0, ..., 3.5]
48+
frame_count_out = max(1, round(frame_count_in / fps_in * fps_out))
49+
source_positions = torch.arange(frame_count_out, dtype=torch.float64, device=images.device) / fps_out * fps_in
50+
51+
if method == cls.METHOD_NEAREST:
52+
resampled = cls._sample_nearest(images, source_positions)
53+
else:
54+
resampled = cls._sample_linear(images, source_positions)
55+
56+
logger.info(f"SwarmVideoResampleFPS: {frame_count_in} frames @ {fps_in} fps -> {frame_count_out} frames @ {fps_out} fps ({method})")
57+
return io.NodeOutput(resampled)
58+
59+
@classmethod
60+
def _sample_nearest(cls, source_frames: torch.Tensor, source_positions: torch.Tensor) -> torch.Tensor:
61+
"""Pick the closest source frame for each fractional position.
62+
63+
See https://ffmpeg.org/ffmpeg-filters.html#fps-1
64+
"""
65+
nearest_idx = source_positions.round().long()
66+
last_valid_idx = source_frames.shape[0] - 1
67+
nearest_idx = torch.clamp(nearest_idx, 0, last_valid_idx)
68+
return source_frames[nearest_idx].contiguous()
69+
70+
@classmethod
71+
def _sample_linear(cls, source_frames: torch.Tensor, source_positions: torch.Tensor) -> torch.Tensor:
72+
"""Linearly blend the two source frames bracketing each fractional position.
73+
74+
See https://ffmpeg.org/ffmpeg-filters.html#framerate
75+
"""
76+
last_valid_idx = source_frames.shape[0] - 1
77+
lower_idx = torch.clamp(source_positions.floor().long(), 0, last_valid_idx)
78+
upper_idx = torch.clamp(lower_idx + 1, 0, last_valid_idx)
79+
blend_weight = (source_positions - lower_idx.to(torch.float64)).to(source_frames.dtype)
80+
while blend_weight.ndim < source_frames.ndim:
81+
blend_weight = blend_weight.unsqueeze(-1)
82+
83+
lower_frames = source_frames[lower_idx]
84+
upper_frames = source_frames[upper_idx]
85+
return ((1.0 - blend_weight) * lower_frames + blend_weight * upper_frames).contiguous()
86+
87+
88+
NODE_CLASS_MAPPINGS = {
89+
"SwarmVideoResampleFPS": SwarmVideoResampleFPS,
90+
}

src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmComfyCommon/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os, folder_paths, traceback
22

3-
from . import SwarmBlending, SwarmImages, SwarmInternalUtil, SwarmKSampler, SwarmLoadImageB64, SwarmLoraLoader, SwarmMasks, SwarmSaveImageWS, SwarmTiling, SwarmExtractLora, SwarmUnsampler, SwarmLatents, SwarmInputNodes, SwarmTextHandling, SwarmReference, SwarmMath, SwarmSam2, SwarmAudio
3+
from . import SwarmBlending, SwarmImages, SwarmInternalUtil, SwarmKSampler, SwarmLoadImageB64, SwarmLoraLoader, SwarmMasks, SwarmSaveImageWS, SwarmTiling, SwarmExtractLora, SwarmUnsampler, SwarmLatents, SwarmInputNodes, SwarmTextHandling, SwarmReference, SwarmMath, SwarmSam2, SwarmAudio, SwarmVideo
44

55
WEB_DIRECTORY = "./web"
66

@@ -23,6 +23,7 @@
2323
| SwarmMath.NODE_CLASS_MAPPINGS
2424
| SwarmSam2.NODE_CLASS_MAPPINGS
2525
| SwarmAudio.NODE_CLASS_MAPPINGS
26+
| SwarmVideo.NODE_CLASS_MAPPINGS
2627
)
2728

2829
try:

src/BuiltinExtensions/ComfyUIBackend/WGNodeData.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,17 @@ public class WGNodeData(JArray _path, WorkflowGenerator _gen, string _dataType,
4848
public int? Frames = null;
4949

5050
/// <summary>The frames per second of a video, if known and valid.</summary>
51-
public int? FPS = null;
51+
public JToken FPS = null;
52+
53+
/// <summary>Returns the FPS as an int, or null if it is a node-ref or unset.</summary>
54+
public int? GetRawFPS()
55+
{
56+
if (FPS is JValue v && v.Type == JTokenType.Integer)
57+
{
58+
return v.Value<int>();
59+
}
60+
return null;
61+
}
5262

5363
/// <summary>If this is a video data object, and audio is separate but tracked, this is the audio associated.</summary>
5464
public WGNodeData AttachedAudio = null;

src/BuiltinExtensions/ComfyUIBackend/WorkflowGenerator.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ public WGNodeData LoadImage(ImageFile img, string param, bool resize, string nod
461461
else
462462
{
463463
WGNodeData attachedAudio = null;
464+
JToken fpsRef = null;
464465
if (img.Type.MetaType == MediaMetaType.Video)
465466
{
466467
result = CreateNode("SwarmLoadVideoB64", new JObject()
@@ -473,6 +474,7 @@ public WGNodeData LoadImage(ImageFile img, string param, bool resize, string nod
473474
});
474475
result = splitNode;
475476
attachedAudio = new([splitNode, 1], this, WGNodeData.DT_AUDIO, CurrentCompat());
477+
fpsRef = NodePath(splitNode, 2);
476478
}
477479
else
478480
{
@@ -496,7 +498,7 @@ public WGNodeData LoadImage(ImageFile img, string param, bool resize, string nod
496498
["crop"] = "disabled"
497499
}, nodeId);
498500
}
499-
return new([result, 0], this, WGNodeData.DT_VIDEO, CurrentCompat()) { AttachedAudio = attachedAudio, Width = imgWidth, Height = imgHeight };
501+
return new([result, 0], this, WGNodeData.DT_VIDEO, CurrentCompat()) { AttachedAudio = attachedAudio, Width = imgWidth, Height = imgHeight, FPS = fpsRef };
500502
}
501503
}
502504
else

src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,19 @@ bool getBestFor(string phrase)
10791079
}
10801080
if (preprocessor.ToLowerFast() != "none")
10811081
{
1082+
if (imageNodeActual.DataType == WGNodeData.DT_VIDEO && imageNodeActual.FPS is not null)
1083+
{
1084+
int fps = g.Text2VideoFPS();
1085+
string resampleNode = g.CreateNode("SwarmVideoResampleFPS", new JObject()
1086+
{
1087+
["images"] = imageNodeActual.Path,
1088+
["fps_in"] = imageNodeActual.FPS,
1089+
["fps_out"] = fps,
1090+
["method"] = "linear"
1091+
});
1092+
imageNodeActual = imageNodeActual.WithPath([resampleNode, 0]);
1093+
imageNodeActual.FPS = fps;
1094+
}
10821095
JArray preprocActual = g.CreatePreprocessor(preprocessor, imageNodeActual);
10831096
g.NodeHelpers["controlnet_preprocessor"] = $"{preprocActual[0]}";
10841097
imageNodeActual = imageNodeActual.WithPath(preprocActual);
@@ -1867,7 +1880,7 @@ void RunSegmentationProcessing(WorkflowGenerator g, bool isBeforeRefiner)
18671880
}
18681881
JArray newInterp = g.DoInterpolation(g.CurrentMedia.Path, method, mult);
18691882
g.CurrentMedia = g.CurrentMedia.WithPath(newInterp);
1870-
int fps = g.CurrentMedia.FPS ?? g.Text2VideoFPS();
1883+
int fps = g.CurrentMedia.GetRawFPS() ?? g.Text2VideoFPS();
18711884
fps *= mult;
18721885
g.CurrentMedia.FPS = fps;
18731886
g.T2VFPSOverride = fps;

0 commit comments

Comments
 (0)