diff --git a/examples/ltx2/ltxt_i2av.py b/examples/ltx2/ltxt_i2av.py index c66b8402d..98c197cd5 100755 --- a/examples/ltx2/ltxt_i2av.py +++ b/examples/ltx2/ltxt_i2av.py @@ -26,6 +26,10 @@ seed = 42 image_path = "/path/to/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg" image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths +# Pixel frame index per image (optional). If None, indices are evenly spaced in [0, num_frames-1] (see create_generator num_frames). +# Example for 3 images and num_frames=121: omit image_frame_idx to get ~[0, 60, 120], or set explicitly: +# image_frame_idx = [0, 40, 120] +image_frame_idx = None prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
save_result_path = "/path/to/save_results/output.mp4" @@ -33,12 +37,14 @@ # Note: image_strength can also be set in config_json # For scalar: image_strength = 1.0 (all images use same strength) # For list: image_strength = [1.0, 0.8] (must match number of images) +# image_frame_idx: list of pixel-frame indices (one per image), or None for even spacing across the clip pipe.generate( seed=seed, prompt=prompt, image_path=image_path, image_strength=image_strength, + image_frame_idx=image_frame_idx, negative_prompt=negative_prompt, save_result_path=save_result_path, ) diff --git a/examples/ltx2/ltxt_i2av_distilled_fp8.py b/examples/ltx2/ltxt_i2av_distilled_fp8.py index 41aa6314c..832ce552d 100755 --- a/examples/ltx2/ltxt_i2av_distilled_fp8.py +++ b/examples/ltx2/ltxt_i2av_distilled_fp8.py @@ -38,6 +38,7 @@ seed = 42 image_path = "/path/to/LightX2V/assets/inputs/imgs/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg" image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths +image_frame_idx = None # Or e.g. [0, 60, 120] β€” pixel frame per image; None = evenly spaced in [0, num_frames-1] prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
save_result_path = "/path/to/LightX2V/save_results/output_lightx2v_ltx2_i2av_distilled_fp8.mp4" @@ -50,6 +51,7 @@ prompt=prompt, image_path=image_path, image_strength=image_strength, + image_frame_idx=image_frame_idx, negative_prompt=negative_prompt, save_result_path=save_result_path, ) diff --git a/lightx2v/infer.py b/lightx2v/infer.py index 2ebd770b4..87fe63eaa 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -104,7 +104,10 @@ def main(): ) parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task") parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task") - parser.add_argument("--image_strength", type=float, default=1.0, help="The strength of the image-to-audio-video (i2av) task") + parser.add_argument("--image_strength", type=str, default="1.0", help="i2av: single float, or comma-separated floats (one per image, or one value broadcast). Example: 1.0 or 1.0,0.85,0.9") + parser.add_argument( + "--image_frame_idx", type=str, default="", help="i2av: comma-separated pixel frame indices (one per image). Omit or empty to evenly space frames in [0, num_frames-1]. Example: 0,40,80" + ) # [Warning] For vace task, need refactor. 
parser.add_argument( "--src_ref_images", diff --git a/lightx2v/models/runners/default_runner.py b/lightx2v/models/runners/default_runner.py index b76634a3e..ef7a3b4b8 100755 --- a/lightx2v/models/runners/default_runner.py +++ b/lightx2v/models/runners/default_runner.py @@ -15,7 +15,7 @@ from lightx2v.utils.generate_task_id import generate_task_id from lightx2v.utils.global_paras import CALIB from lightx2v.utils.profiler import * -from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, mux_audio_from_video, save_to_video, wan_vae_to_comfy +from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, mux_audio_from_video, save_to_image, save_to_video, wan_vae_to_comfy from lightx2v_platform.base.global_var import AI_DEVICE torch_device_module = getattr(torch, AI_DEVICE) @@ -454,16 +454,26 @@ def process_images_after_vae_decoder(self): fps = self.config.get("fps", 16) if not dist.is_initialized() or dist.get_rank() == 0: - logger.info(f"🎬 Start to save video 🎬") - - save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg") - if self.config.get("task") == "sr": - input_video_path = getattr(self.input_info, "video_path", "") - if input_video_path: - muxed_path = mux_audio_from_video(input_video_path, self.input_info.save_result_path) - if muxed_path: - logger.info(f"Audio muxed from input video: {input_video_path}") - logger.info(f"βœ… Video saved successfully to: {self.input_info.save_result_path} βœ…") + out_path = self.input_info.save_result_path + img_in = (getattr(self.input_info, "image_path", None) or "").strip() + vid_in = (getattr(self.input_info, "video_path", None) or "").strip() + sr_from_image_only = self.config.get("task") == "sr" and bool(img_in) and not bool(vid_in) + + if sr_from_image_only: + logger.info("πŸ–Ό Start to save SR image (image_path input, no video_path) πŸ–Ό") + save_to_image(self.gen_video_final, out_path) + logger.info(f"βœ… Image saved 
def _ltx2_parse_image_paths(image_path: str) -> list[str]:
    """Split a comma-separated ``image_path`` value into individual paths.

    Whitespace around each entry is stripped and empty entries (for example
    from a trailing comma) are dropped. ``None`` and ``""`` both yield ``[]``
    so callers do not have to guard the argument themselves.
    """
    if not image_path:
        return []
    stripped = (entry.strip() for entry in image_path.split(","))
    return [entry for entry in stripped if entry]


def _ltx2_normalize_image_strengths(image_strength, n: int) -> list[float]:
    """Expand ``image_strength`` into one float per conditioning image.

    Args:
        image_strength: ``None`` (falls back to 1.0), a single number, or a
            list/tuple holding either one value (broadcast to every image) or
            exactly ``n`` values.
        n: Number of conditioning images.

    Returns:
        A list of ``n`` floats.

    Raises:
        ValueError: If a list/tuple has a length other than 1 or ``n``.
    """
    if image_strength is None:
        # Unset strength: use full strength, matching the CLI default of "1.0".
        return [1.0] * n
    if not isinstance(image_strength, (list, tuple)):
        return [float(image_strength)] * n
    if len(image_strength) == 1:
        return [float(image_strength[0])] * n
    if len(image_strength) != n:
        raise ValueError(f"i2av image_strength: expected 1 or {n} values (scalar or list), got length {len(image_strength)}")
    return [float(value) for value in image_strength]


def _ltx2_resolve_pixel_frame_indices(image_frame_idx, n: int, num_frames: int) -> list[int]:
    """Return the pixel-frame index each of the ``n`` images is anchored to.

    Args:
        image_frame_idx: Explicit indices (one per image), or ``None``/empty to
            spread the images evenly over ``[0, num_frames - 1]``.
        n: Number of conditioning images.
        num_frames: Length of the clip in pixel frames.

    Returns:
        ``n`` integer indices. Explicit indices are truncated with ``int`` and
        clamped into ``[0, num_frames - 1]`` rather than rejected.

    Raises:
        ValueError: If explicit indices are given but their count is not ``n``.
    """
    if not image_frame_idx:
        if n == 1:
            return [0]
        if num_frames <= 1:
            # A clip with a single frame has only one place to anchor images.
            return [0] * n
        last = num_frames - 1
        return [round(i * last / (n - 1)) for i in range(n)]
    if len(image_frame_idx) != n:
        raise ValueError(f"i2av image_frame_idx: expected {n} indices (one per image), got {len(image_frame_idx)}")
    last = max(num_frames - 1, 0)
    return [min(max(int(idx), 0), last) for idx in image_frame_idx]


def _ltx2_pixel_to_latent_frame_idx(pixel_frame_idx: int, temporal_scale: int) -> int:
    """Map a pixel-frame index to the latent frame that contains it.

    Pixel frame 0 maps to latent frame 0; after that every run of
    ``temporal_scale`` pixel frames shares one latent frame (1..scale -> 1,
    scale+1..2*scale -> 2, ...).

    Raises:
        ValueError: If ``temporal_scale`` is below 1 or ``pixel_frame_idx`` is
            negative. A negative result would silently wrap around when used
            as a tensor index, so it is rejected here instead.
    """
    if temporal_scale < 1:
        raise ValueError(f"temporal_scale must be >= 1, got {temporal_scale}")
    if pixel_frame_idx < 0:
        raise ValueError(f"pixel_frame_idx must be >= 0, got {pixel_frame_idx}")
    if pixel_frame_idx == 0:
        return 0
    return (pixel_frame_idx - 1) // temporal_scale + 1


def _ltx2_resize_video_denoise_mask_for_stage2(mask: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    """Resize stage-1 unpatchified video denoise mask to stage-2 latent spatial size.

    Args:
        mask: 4-D mask laid out as ``[1, F, H, W]``.
        target_h: Target latent height.
        target_w: Target latent width.

    Returns:
        A contiguous float32 tensor of shape ``[1, F, target_h, target_w]``.
        The frame axis is left untouched.

    Raises:
        ValueError: If ``mask`` is not 4-D.
    """
    if mask.dim() != 4:
        raise ValueError(f"video denoise mask must be 4-D [1, F, H, W], got shape {tuple(mask.shape)}")
    # interpolate() resizes every (batch, channel) plane independently, so the
    # frame axis can stay in the channel slot; no permute round-trip is needed.
    resized = torch.nn.functional.interpolate(
        mask.to(dtype=torch.float32),
        size=(target_h, target_w),
        mode="nearest",
    )
    return resized.contiguous()
x in info.image_strength.split(",") if x.strip()] + info.image_strength = 1.0 if not p else (p[0] if len(p) == 1 else p) + if isinstance(info.image_frame_idx, str): + p = [int(x.strip()) for x in info.image_frame_idx.split(",") if x.strip()] + info.image_frame_idx = p or None + n = len(_ltx2_parse_image_paths(info.image_path or "")) + if n == 0: + return + st, fi = info.image_strength, info.image_frame_idx + if isinstance(st, list) and len(st) not in (1, n): + raise ValueError(f"i2av image_strength: need 1 or {n} values, got {len(st)}") + if fi is not None and len(fi) != n: + raise ValueError(f"i2av image_frame_idx: need {n} indices, got {len(fi)}") + @ProfilingContext4DebugL2("Run Encoders") def _run_input_encoder_local_i2av(self): + self._normalize_i2av_input_fields() self.input_info.video_latent_shape, self.input_info.audio_latent_shape = self.get_latent_shape_with_target_hw() text_encoder_output = self.run_text_encoder(self.input_info) self.video_denoise_mask, self.initial_video_latent = self.run_vae_encoder() @@ -231,14 +293,27 @@ def run_vae_encoder(self): device=AI_DEVICE, ) - # Process each image conditioning - image_paths = self.input_info.image_path.split(",") # image_path1,image_path2,image_path3 - for frame_idx, image_path in enumerate(image_paths): - if not isinstance(self.input_info.image_strength, list): - strength = self.input_info.image_strength - else: - strength = self.input_info.image_strength[frame_idx] - logger.info(f" πŸ“· Loading image: {image_path} for frame {frame_idx} with strength {strength}") + image_paths = _ltx2_parse_image_paths(self.input_info.image_path) + n = len(image_paths) + if n == 0: + logger.warning("i2av: image_path is empty, skipping image conditioning") + self._i2av_guiding_keyframe_meta = None + torch_device_module.empty_cache() + gc.collect() + return video_denoise_mask, initial_video_latent + + num_frames = self.input_info.target_video_length or self.config.get("target_video_length", 1) + strengths = 
_ltx2_normalize_image_strengths(self.input_info.image_strength, n) + raw_frame_idx = getattr(self.input_info, "image_frame_idx", None) + pixel_frame_indices = _ltx2_resolve_pixel_frame_indices(raw_frame_idx, n, num_frames) + temporal_scale = int(self.config["vae_scale_factors"][0]) + + guiding_keyframe_meta: list[tuple[str, int, float]] = [] + + for i, image_path in enumerate(image_paths): + strength = strengths[i] + pixel_frame_idx = pixel_frame_indices[i] + logger.info(f" πŸ“· Loading image: {image_path} pixel_frame={pixel_frame_idx} strength={strength} ({i + 1}/{n})") # Load and preprocess image image = load_image_conditioning( @@ -254,18 +329,15 @@ def run_vae_encoder(self): encoded_latent = encoded_latent.squeeze(0) - # Verify frame index is valid - if frame_idx < 0 or frame_idx >= F: - logger.warning(f"⚠️ Frame index {frame_idx} out of range [0, {F - 1}], skipping") + # Pixel frame 0 β†’ write into the latent time slot; other frames β†’ guiding tokens appended in the scheduler. + if pixel_frame_idx != 0: + guiding_keyframe_meta.append((image_path, pixel_frame_idx, strength)) continue # Get the latent frame index by converting pixel frame to latent frame # For LTX2, temporal compression is 8x, so latent_frame_idx = (frame_idx - 1) // 8 + 1 for frame_idx > 0 # or 0 for frame_idx == 0 - if frame_idx == 0: - latent_frame_idx = 0 - else: - latent_frame_idx = (frame_idx - 1) // self.config["vae_scale_factors"][0] + 1 + latent_frame_idx = _ltx2_pixel_to_latent_frame_idx(pixel_frame_idx, temporal_scale) if latent_frame_idx >= F: logger.warning(f"⚠️ Latent frame index {latent_frame_idx} out of range [0, {F - 1}], skipping") @@ -281,6 +353,7 @@ def run_vae_encoder(self): video_denoise_mask[:, latent_frame_idx, :, :] = 1.0 - strength logger.info(f" βœ“ Encoded image to latent frame {latent_frame_idx}") + self._i2av_guiding_keyframe_meta = guiding_keyframe_meta torch_device_module.empty_cache() gc.collect() @@ -289,6 +362,26 @@ def run_vae_encoder(self): return 
def _build_i2av_video_guiding_latents(self):
    """Encode the stored guiding keyframes at the current ``target_shape``.

    Reads the ``(path, pixel_frame_idx, strength)`` tuples kept in
    ``self._i2av_guiding_keyframe_meta``, loads every image at the height and
    width currently held in ``self.input_info.target_shape`` and runs it
    through ``self.video_vae.encode``.

    Returns:
        ``None`` when no keyframe metadata is stored (attribute missing, ``None``
        or empty); otherwise a list of ``(encoded_latent, pixel_frame_idx,
        strength)`` tuples in the stored order.
    """
    keyframes = getattr(self, "_i2av_guiding_keyframe_meta", None)
    if not keyframes:
        return None

    target_h = self.input_info.target_shape[0]
    target_w = self.input_info.target_shape[1]

    guiding = []
    for image_path, pixel_frame_idx, strength in keyframes:
        conditioning = load_image_conditioning(
            image_path=image_path,
            height=target_h,
            width=target_w,
            dtype=GET_DTYPE(),
            device=AI_DEVICE,
        )
        # Inference-only encode: no autograd graph is needed.
        with torch.no_grad():
            encoded = self.video_vae.encode(conditioning).squeeze(0)
        guiding.append((encoded, pixel_frame_idx, strength))
    return guiding
prepare_kwargs["video_guiding_latents"] = vg + self.model.scheduler.prepare(**prepare_kwargs) def init_run(self): diff --git a/lightx2v/models/schedulers/ltx2/scheduler.py b/lightx2v/models/schedulers/ltx2/scheduler.py index 9e67f4c11..fad98bdbd 100755 --- a/lightx2v/models/schedulers/ltx2/scheduler.py +++ b/lightx2v/models/schedulers/ltx2/scheduler.py @@ -10,7 +10,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple +from typing import List, Optional, Tuple import einops import torch @@ -334,7 +334,7 @@ def __init__(self, config): # Patchifier configuration self.video_patch_size = config.get("video_patch_size", 1) self.fps = config.get("fps", 24) # Frames per second for position calculation - self.video_scale_factors = config.get("video_scale_factors", (8, 32, 32)) # (time, height, width) + self.video_scale_factors = tuple(config.get("vae_scale_factors") or config.get("video_scale_factors", (8, 32, 32))) # (time, height, width); merged LTX config uses vae_scale_factors # Initialize patchifiers self.video_patchifier = VideoLatentPatchifier(patch_size=self.video_patch_size) @@ -382,6 +382,7 @@ def prepare( noise_scale: float = 1.0, video_denoise_mask: Optional[torch.Tensor] = None, audio_denoise_mask: Optional[torch.Tensor] = None, + video_guiding_latents: Optional[List[Tuple[torch.Tensor, int, float]]] = None, ): """ Prepare scheduler for inference. @@ -395,6 +396,9 @@ def prepare( noise_scale: Scale factor for noise video_denoise_mask: Optional denoise mask for video (unpatchified) audio_denoise_mask: Optional denoise mask for audio (unpatchified) + video_guiding_latents: Optional list of (encoded_latent [C,1,H,W], pixel_frame_idx, strength). + Each keyframe is patchified and concatenated after the main video tokens, with temporal positions + offset by `pixel_frame_idx` (see `_append_video_guiding_keyframes`). 
""" # Reset step state (important for stage 2 after stage 1) self.step_index = 0 @@ -402,6 +406,7 @@ def prepare( self.a_noise_pred = None self.mm_last_v_pred = None self.mm_last_a_pred = None + self._video_main_num_tokens = None # Initialize generator self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed) @@ -415,10 +420,11 @@ def prepare( noise_scale=noise_scale, video_denoise_mask=video_denoise_mask, audio_denoise_mask=audio_denoise_mask, + video_guiding_latents=video_guiding_latents, ) if self.sigmas is None: - self.set_timesteps(infer_steps=self.infer_steps) + self.set_timesteps(infer_steps=self.infer_steps, latent=self.video_latent_state.latent) def prepare_latents( self, @@ -429,6 +435,7 @@ def prepare_latents( noise_scale: float = 1.0, video_denoise_mask: Optional[torch.Tensor] = None, audio_denoise_mask: Optional[torch.Tensor] = None, + video_guiding_latents: Optional[List[Tuple[torch.Tensor, int, float]]] = None, ): """ Prepare initial latents for denoising and patchify them. @@ -458,6 +465,7 @@ def prepare_latents( noise_scale=noise_scale, video_denoise_mask=video_denoise_mask, dtype=GET_DTYPE(), + video_guiding_latents=video_guiding_latents, ) # Prepare audio latents @@ -476,6 +484,7 @@ def _prepare_video_latents( noise_scale: float = 1.0, video_denoise_mask: Optional[torch.Tensor] = None, dtype: torch.dtype = None, + video_guiding_latents: Optional[List[Tuple[torch.Tensor, int, float]]] = None, ): """ Prepare video latents for denoising. @@ -486,6 +495,7 @@ def _prepare_video_latents( noise_scale: Scale factor for noise video_denoise_mask: Optional denoise mask for video (unpatchified) dtype: Data type for latents + video_guiding_latents: Optional guiding latents (see prepare()). """ _, frames_v, height_v, width_v = video_latent_shape @@ -553,6 +563,9 @@ def _prepare_video_latents( positions_video[0, ...] = positions_video[0, ...] 
def _append_video_guiding_keyframes(
    self,
    keyframes: List[Tuple[torch.Tensor, int, float]],
    height_v: int,
    width_v: int,
    dtype: torch.dtype,
    noise_scale: float,
) -> None:
    """Append guiding-keyframe tokens after the main video token grid.

    Every keyframe latent is patchified, given positions whose temporal
    component is shifted by its ``pixel_frame_idx`` and divided by
    ``self.fps``, noised according to its strength, and then concatenated onto
    ``self.video_latent_state`` (latent, clean latent, denoise mask, positions).

    Args:
        keyframes: ``(encoded_latent, pixel_frame_idx, strength)`` tuples; each
            latent must be 4-D ``[C, 1, H, W]``.
        height_v: Latent height of the main video grid.
        width_v: Latent width of the main video grid.
        dtype: Dtype the appended positions are cast to.
        noise_scale: Scale applied to the per-keyframe denoise mask when mixing
            noise with the clean tokens.

    Raises:
        ValueError: If a latent is not 4-D, has more than one frame, or its
            spatial size differs from ``(height_v, width_v)``.
    """
    if not keyframes:
        return

    # Validate everything first so a bad keyframe cannot leave the latent
    # state extended by only some of the keyframes.
    for enc, _, _ in keyframes:
        if enc.dim() != 4:
            raise ValueError(f"guiding latent must be [C,1,H,W], got shape {tuple(enc.shape)}")
        _, f, h, w = enc.shape
        if f != 1:
            raise ValueError(f"guiding latent must have F=1, got F={f}")
        if h != height_v or w != width_v:
            raise ValueError(f"guiding latent spatial shape ({h},{w}) must match video_latent_shape ({height_v},{width_v})")

    state = self.video_latent_state
    # Collect the pieces and concatenate once at the end; concatenating inside
    # the loop would copy the whole (large) main grid for every keyframe.
    latents = [state.latent]
    clean_latents = [state.clean_latent]
    masks = [state.denoise_mask]
    positions = [state.positions]

    for enc, pixel_frame_idx, strength in keyframes:
        patch_tokens = self.video_patchifier.patchify(enc)
        num_tokens = patch_tokens.shape[0]

        latent_coords = self.video_patchifier.get_patch_grid_bounds(1, height_v, width_v, AI_DEVICE)
        pos_k = get_pixel_coords(
            latent_coords,
            self.video_scale_factors,
            causal_fix=(pixel_frame_idx == 0),
        )
        pos_k = pos_k.float()
        # Index 0 of the positions is the temporal component: shift it to the
        # keyframe's pixel frame, then convert frames to seconds.
        pos_k[0, ...] = pos_k[0, ...] + float(pixel_frame_idx)
        pos_k[0, ...] = pos_k[0, ...] / float(self.fps)
        pos_k = pos_k.to(dtype)

        # strength 1.0 -> mask 0 (keep the keyframe clean); lower strength lets noise in.
        mask_k = torch.full((num_tokens, 1), 1.0 - float(strength), device=AI_DEVICE, dtype=torch.float32)
        clean_k = patch_tokens.clone()

        noise_k = torch.randn(
            *patch_tokens.shape,
            dtype=patch_tokens.dtype,
            device=AI_DEVICE,
            generator=self.generator,
        )
        scaled_mask = mask_k * noise_scale
        noised_k = (noise_k * scaled_mask + clean_k * (1 - scaled_mask)).to(patch_tokens.dtype)

        latents.append(noised_k)
        clean_latents.append(clean_k)
        masks.append(mask_k)
        positions.append(pos_k)

    state.latent = torch.cat(latents, dim=0)
    state.clean_latent = torch.cat(clean_latents, dim=0)
    state.denoise_mask = torch.cat(masks, dim=0)
    # Positions keep their leading coordinate axis; tokens live on dim 1.
    state.positions = torch.cat(positions, dim=1)
diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index 57712e103..702bcf80d 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -415,6 +415,7 @@ def generate( image_path=None, video_path=None, # For SR task (video super-resolution) image_strength=None, + image_frame_idx=None, last_frame_path=None, audio_path=None, src_ref_images=None, @@ -427,6 +428,7 @@ def generate( # Run inference (following LightX2V pattern) # Note: image_path supports comma-separated paths for multiple images # image_strength can be a scalar (float/int) or a list matching the number of images + # image_frame_idx: optional list of pixel frame indices (one per image), or None to evenly space in [0, num_frames-1] self.seed = seed self.image_path = image_path self.video_path = video_path # For SR task @@ -442,6 +444,7 @@ def generate( self.return_result_tensor = return_result_tensor self.target_shape = target_shape self.image_strength = image_strength + self.image_frame_idx = image_frame_idx if task is not None: self.task = task self.modify_config({"task": self.task}) diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py index ba5e7d64d..1a645ffa8 100755 --- a/lightx2v/utils/input_info.py +++ b/lightx2v/utils/input_info.py @@ -1,6 +1,6 @@ import inspect from dataclasses import MISSING, dataclass, field, fields, make_dataclass -from typing import Any +from typing import Any, Optional import torch @@ -232,6 +232,7 @@ class T2AVInputInfo: audio_latent_shape: list = field(default_factory=list) latent_shape: list = field(default_factory=list) target_shape: list = field(default_factory=list) + target_video_length: int = field(default_factory=int) @dataclass @@ -242,6 +243,7 @@ class I2AVInputInfo: negative_prompt: str = field(default_factory=str) image_path: str = field(default_factory=str) image_strength: float = field(default_factory=float) + image_frame_idx: Optional[list[int]] = None save_result_path: str = field(default_factory=str) return_result_tensor: bool 
def save_to_image(images: torch.Tensor, output_path: str) -> None:
    """Save the first frame of ComfyUI tensor ``[N, H, W, C]`` (values in ``[0, 1]``) as a static image.

    Used for ``task=sr`` when conditioning comes from ``image_path`` only (no ``video_path``).

    Args:
        images: 4-D tensor ``[N, H, W, C]`` with ``C == 3``. Only ``images[0]``
            is written; values outside ``[0, 1]`` are clamped.
        output_path: Destination file. Missing parent directories are created;
            the file is written through ``PIL.Image.save``.

    Raises:
        AssertionError: If ``images`` is not 4-D with 3 channels last.
        ValueError: If ``images`` holds no frames.
    """
    assert images.dim() == 4 and images.shape[-1] == 3, "Input must be [N, H, W, C] with C=3"
    if images.shape[0] == 0:
        raise ValueError("save_to_image: got an empty batch, there is no frame to save")

    # detach() + float(): Tensor.numpy() rejects grad-tracking and bfloat16
    # tensors, and scaling in float32 avoids half-precision rounding.
    frame = images[0].detach().clamp(0, 1).float().cpu().numpy()
    frame_u8 = (frame * 255.0).round().astype(np.uint8)

    # dirname is "" for a bare file name; fall back to the current directory.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    Image.fromarray(frame_u8).save(output_path)
moving toward a rear-three-quarter chase; stage four is a dedicated side-on beat: the same rider and bike held in a stable lateral profile view, camera tracking parallel to the road on the outside of the curve at roughly bike height, wheels, fairing, helmet, and lean line clearly readable against the mountain and asphalt, as if cutting from chase to a classic side tracking shot. Between stage three and four, interpolate the viewpoint smoothly: continue the exit energy and speed, then ease the camera from rear-side chase into pure side profile without a hard cutβ€”match direction of travel, horizon, and lighting so the handoff feels like one continuous take. Keep rider identity and bike appearance consistent across all stages, with strong temporal continuity, dynamic motion blur, and photorealistic detail. Emphasize realistic synchronized audio: deep engine rumble at approach, sharp downshift blips and brief backfire pops entering the turn, tire scrub and short skid noise at apex with wind rush, then a rising high-RPM engine roar and accelerating exhaust on exit, steady wind and tire noise under the side-on pass, no background music." 
\ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
\ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av_keyframes2.mp4 \ +--image_strength 1.0,0.7,0.7,0.7 \ +--image_frame_idx 0,120,240,360 \ +--target_video_length 361 diff --git a/scripts/seedvr2/run_seedvr2_3b_image_sr.sh b/scripts/seedvr2/run_seedvr2_3b_image_sr.sh new file mode 100644 index 000000000..4dce0704e --- /dev/null +++ b/scripts/seedvr2/run_seedvr2_3b_image_sr.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=/path/to/ByteDance-Seed/SeedVR2-3B + +image_path=${lightx2v_path}/assets/inputs/imgs/frame_1.png + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls seedvr2 \ +--task sr \ +--sr_ratio 2.0 \ +--image_path $image_path \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seedvr/seedvr2_3b.json \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seedvr2_image_sr.png