@@ -91,42 +91,45 @@ def execute(cls, images: torch.Tensor, fps_in: float, fps_out: float, method: st
9191
9292 @classmethod
9393 def _source_positions (cls , frame_count_out : int , fps_in : float , fps_out : float , device : torch .device ) -> torch .Tensor :
94- """Fractional source-frame index for each output frame .
94+ """For each output frame, the (fractional) source-frame index it should display .
9595
96- Each output frame should display what the source had at the same
97- timestamp. The output frame at index i plays at time i / fps_out, and
98- the source frame visible at that time is at index i * (fps_in / fps_out).
96+ Computed in two steps:
97+ 1. Convert each output-frame index into the timestamp (in seconds) at
98+ which that frame will be shown: timestamp = index / fps_out.
99+ 2. Convert that timestamp into the source-frame index visible at the
100+ same moment in the original video: source_index = timestamp * fps_in.
99101 """
100102 output_indices = torch .arange (frame_count_out , dtype = torch .float64 , device = device )
101- return output_indices * (fps_in / fps_out )
103+ output_timestamps_sec = output_indices / fps_out
104+ return output_timestamps_sec * fps_in
102105
103106 @classmethod
104107 def _sample_nearest (cls , source_frames : torch .Tensor , source_positions : torch .Tensor ) -> torch .Tensor :
105108 """Pick the closest source frame for each fractional position.
106-
109+
107110 See https://ffmpeg.org/ffmpeg-filters.html#fps-1
108111 """
109- last_idx = source_frames .shape [0 ] - 1
110- nearest_idx = torch .clamp (source_positions .round ().long (), 0 , last_idx )
112+ nearest_idx = source_positions .round ().long ()
113+ last_valid_idx = source_frames .shape [0 ] - 1
114+ nearest_idx = torch .clamp (nearest_idx , 0 , last_valid_idx )
111115 return source_frames [nearest_idx ].contiguous ()
112116
113117 @classmethod
114118 def _sample_linear (cls , source_frames : torch .Tensor , source_positions : torch .Tensor ) -> torch .Tensor :
115119 """Linearly blend the two source frames bracketing each fractional position.
116-
120+
117121 See https://ffmpeg.org/ffmpeg-filters.html#framerate
118122 """
119- last_idx = source_frames .shape [0 ] - 1
120- lower_idx = torch .clamp (source_positions .floor ().long (), 0 , last_idx )
121- upper_idx = torch .clamp (lower_idx + 1 , 0 , last_idx )
122-
123+ last_valid_idx = source_frames .shape [0 ] - 1
124+ lower_idx = torch .clamp (source_positions .floor ().long (), 0 , last_valid_idx )
125+ upper_idx = torch .clamp (lower_idx + 1 , 0 , last_valid_idx )
123126 blend_weight = (source_positions - lower_idx .to (torch .float64 )).to (source_frames .dtype )
124- # Reshape weight to [N_out, 1, 1, ...] so it broadcasts across the H/W/C
125- # dims of the per-frame tensors during the blend.
126- broadcast_shape = (- 1 ,) + (1 ,) * (source_frames .ndim - 1 )
127- blend_weight = blend_weight .view (* broadcast_shape )
127+ while blend_weight .ndim < source_frames .ndim :
128+ blend_weight = blend_weight .unsqueeze (- 1 )
128129
129- return ((1.0 - blend_weight ) * source_frames [lower_idx ] + blend_weight * source_frames [upper_idx ]).contiguous ()
130+ lower_frames = source_frames [lower_idx ]
131+ upper_frames = source_frames [upper_idx ]
132+ return ((1.0 - blend_weight ) * lower_frames + blend_weight * upper_frames ).contiguous ()
130133
131134
132135NODE_CLASS_MAPPINGS = {
0 commit comments