@@ -129,22 +129,38 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
129129 return img_byte_arr
130130
131131
132+ def _compute_downscale_dims (src_w : int , src_h : int , total_pixels : int ) -> tuple [int , int ] | None :
133+ """Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits.
134+
135+ Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
136+ are rounded down to even values (many codecs require divisible-by-2).
137+ """
138+ pixels = src_w * src_h
139+ if pixels <= total_pixels :
140+ return None
141+ scale = math .sqrt (total_pixels / pixels )
142+ new_w = max (2 , int (src_w * scale ))
143+ new_h = max (2 , int (src_h * scale ))
144+ new_w -= new_w % 2
145+ new_h -= new_h % 2
146+ return new_w , new_h
147+
148+
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
    """Downscale input image tensor to roughly the specified total pixels.

    Output dimensions are rounded down to even values so that the result is
    guaranteed to fit within ``total_pixels`` and is compatible with codecs
    that require even dimensions (e.g. yuv420p).
    """
    # Bring channels forward for common_upscale; assumes channels-last input
    # (..., H, W, C) — TODO confirm against callers.
    chw = image.movedim(-1, 1)
    target = _compute_downscale_dims(chw.shape[3], chw.shape[2], int(total_pixels))
    if target is None:
        # Already within the pixel budget; return the caller's tensor untouched.
        return image
    width, height = target
    resized = common_upscale(chw, width, height, "lanczos", "disabled")
    # Restore the original channels-last layout.
    return resized.movedim(1, -1)
145161
146162
147- def downscale_image_tensor_by_max_side (image : torch .Tensor , * , max_side : int ) -> torch .Tensor :
163+ def downscale_image_tensor_by_max_side (image : torch .Tensor , * , max_side : int ) -> torch .Tensor :
148164 """Downscale input image tensor so the largest dimension is at most max_side pixels."""
149165 samples = image .movedim (- 1 , 1 )
150166 height , width = samples .shape [2 ], samples .shape [3 ]
@@ -399,6 +415,72 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
399415 raise RuntimeError (f"Failed to trim video: { str (e )} " ) from e
400416
401417
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
    """Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.

    Returns the original video object untouched when it already fits. Preserves
    frame rate, duration, and audio. Aspect ratio is preserved up to a fraction
    of a percent (even-dim rounding).
    """
    width, height = video.get_dimensions()
    target_dims = _compute_downscale_dims(width, height, total_pixels)
    # None signals the source is already within budget — skip the re-encode entirely.
    return video if target_dims is None else _apply_video_scale(video, target_dims)
429+
430+
def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video:
    """Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass.

    The video track is re-encoded as h264/yuv420p into an in-memory MP4; the
    first audio stream found (if any) is re-encoded as AAC with its original
    sample rate and channel layout. Raises ``RuntimeError`` wrapping the
    underlying error on failure.
    """
    out_w, out_h = scale_dims
    output_buffer = BytesIO()
    input_container = None
    output_container = None

    try:
        input_source = video.get_stream_source()
        input_container = av.open(input_source, mode="r")
        # Encode into memory rather than a temporary file.
        output_container = av.open(output_buffer, mode="w", format="mp4")

        video_stream = output_container.add_stream("h264", rate=video.get_frame_rate())
        video_stream.width = out_w
        video_stream.height = out_h
        # yuv420p requires even dimensions; scale_dims is expected to already
        # be even (see _compute_downscale_dims) — TODO confirm all callers.
        video_stream.pix_fmt = "yuv420p"

        # Mirror only the first audio stream, preserving its sample rate and layout.
        audio_stream = None
        for stream in input_container.streams:
            if isinstance(stream, av.AudioStream):
                audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
                audio_stream.sample_rate = stream.sample_rate
                audio_stream.layout = stream.layout
                break

        # Pass 1: decode the video track, rescale each frame, encode and mux.
        for frame in input_container.decode(video=0):
            frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
            for packet in video_stream.encode(frame):
                output_container.mux(packet)
        # Flush frames still buffered inside the encoder.
        for packet in video_stream.encode():
            output_container.mux(packet)

        if audio_stream is not None:
            # Pass 2: rewind the input and transcode the audio track.
            input_container.seek(0)
            for audio_frame in input_container.decode(audio=0):
                for packet in audio_stream.encode(audio_frame):
                    output_container.mux(packet)
            # Flush the audio encoder as well.
            for packet in audio_stream.encode():
                output_container.mux(packet)

        # Closing the output container finalizes the MP4 before reading the buffer.
        output_container.close()
        input_container.close()
        output_buffer.seek(0)
        return InputImpl.VideoFromFile(output_buffer)

    except Exception as e:
        # Best-effort cleanup of whichever containers were opened before re-raising.
        if input_container is not None:
            input_container.close()
        if output_container is not None:
            output_container.close()
        raise RuntimeError(f"Failed to resize video: {str(e)}") from e
482+
483+
402484def _f32_pcm (wav : torch .Tensor ) -> torch .Tensor :
403485 """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
404486 if wav .dtype .is_floating_point :
0 commit comments