Skip to content

Commit 27171c9

Browse files
committed
Adds SwarmVideoResampleFPS; resamples controlnet preview videos
1 parent b56108b commit 27171c9

1 file changed

Lines changed: 21 additions & 18 deletions

File tree

  • src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmComfyCommon

src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmComfyCommon/SwarmVideo.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -91,42 +91,45 @@ def execute(cls, images: torch.Tensor, fps_in: float, fps_out: float, method: st
9191

9292
@classmethod
9393
def _source_positions(cls, frame_count_out: int, fps_in: float, fps_out: float, device: torch.device) -> torch.Tensor:
94-
"""Fractional source-frame index for each output frame.
94+
"""For each output frame, the (fractional) source-frame index it should display.
9595
96-
Each output frame should display what the source had at the same
97-
timestamp. The output frame at index i plays at time i / fps_out, and
98-
the source frame visible at that time is at index i * (fps_in / fps_out).
96+
Computed in two steps:
97+
1. Convert each output-frame index into the timestamp (in seconds) at
98+
which that frame will be shown: timestamp = index / fps_out.
99+
2. Convert that timestamp into the source-frame index visible at the
100+
same moment in the original video: source_index = timestamp * fps_in.
99101
"""
100102
output_indices = torch.arange(frame_count_out, dtype=torch.float64, device=device)
101-
return output_indices * (fps_in / fps_out)
103+
output_timestamps_sec = output_indices / fps_out
104+
return output_timestamps_sec * fps_in
102105

103106
@classmethod
104107
def _sample_nearest(cls, source_frames: torch.Tensor, source_positions: torch.Tensor) -> torch.Tensor:
105108
"""Pick the closest source frame for each fractional position.
106-
109+
107110
See https://ffmpeg.org/ffmpeg-filters.html#fps-1
108111
"""
109-
last_idx = source_frames.shape[0] - 1
110-
nearest_idx = torch.clamp(source_positions.round().long(), 0, last_idx)
112+
nearest_idx = source_positions.round().long()
113+
last_valid_idx = source_frames.shape[0] - 1
114+
nearest_idx = torch.clamp(nearest_idx, 0, last_valid_idx)
111115
return source_frames[nearest_idx].contiguous()
112116

113117
@classmethod
114118
def _sample_linear(cls, source_frames: torch.Tensor, source_positions: torch.Tensor) -> torch.Tensor:
115119
"""Linearly blend the two source frames bracketing each fractional position.
116-
120+
117121
See https://ffmpeg.org/ffmpeg-filters.html#framerate
118122
"""
119-
last_idx = source_frames.shape[0] - 1
120-
lower_idx = torch.clamp(source_positions.floor().long(), 0, last_idx)
121-
upper_idx = torch.clamp(lower_idx + 1, 0, last_idx)
122-
123+
last_valid_idx = source_frames.shape[0] - 1
124+
lower_idx = torch.clamp(source_positions.floor().long(), 0, last_valid_idx)
125+
upper_idx = torch.clamp(lower_idx + 1, 0, last_valid_idx)
123126
blend_weight = (source_positions - lower_idx.to(torch.float64)).to(source_frames.dtype)
124-
# Reshape weight to [N_out, 1, 1, ...] so it broadcasts across the H/W/C
125-
# dims of the per-frame tensors during the blend.
126-
broadcast_shape = (-1,) + (1,) * (source_frames.ndim - 1)
127-
blend_weight = blend_weight.view(*broadcast_shape)
127+
while blend_weight.ndim < source_frames.ndim:
128+
blend_weight = blend_weight.unsqueeze(-1)
128129

129-
return ((1.0 - blend_weight) * source_frames[lower_idx] + blend_weight * source_frames[upper_idx]).contiguous()
130+
lower_frames = source_frames[lower_idx]
131+
upper_frames = source_frames[upper_idx]
132+
return ((1.0 - blend_weight) * lower_frames + blend_weight * upper_frames).contiguous()
130133

131134

132135
NODE_CLASS_MAPPINGS = {

0 commit comments

Comments
 (0)