Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/ltx2/ltxt_i2av.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,25 @@
seed = 42
image_path = "/path/to/woman.jpeg"  # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg"
image_strength = 1.0  # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths
# Pixel frame index per image (optional). If None, indices are evenly spaced in [0, num_frames-1] (see create_generator num_frames).
# Example for 3 images and num_frames=121: omit image_frame_idx to get ~[0, 60, 120], or set explicitly:
# image_frame_idx = [0, 40, 120]
image_frame_idx = None
prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance."
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
save_result_path = "/path/to/save_results/output.mp4"

# Note: image_strength can also be set in config_json
# For scalar: image_strength = 1.0 (all images use same strength)
# For list: image_strength = [1.0, 0.8] (must match number of images)
# image_frame_idx: list of pixel-frame indices (one per image), or None for even spacing across the clip

# NOTE(review): `pipe` is constructed earlier in this example (not shown here);
# presumably an LTX2 i2av pipeline — confirm against the full example file.
pipe.generate(
seed=seed,
prompt=prompt,
image_path=image_path,
image_strength=image_strength,
image_frame_idx=image_frame_idx,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
2 changes: 2 additions & 0 deletions examples/ltx2/ltxt_i2av_distilled_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
seed = 42
image_path = "/path/to/LightX2V/assets/inputs/imgs/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg"
image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths
image_frame_idx = None # Or e.g. [0, 60, 120] — pixel frame per image; None = evenly spaced in [0, num_frames-1]
prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance."
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
save_result_path = "/path/to/LightX2V/save_results/output_lightx2v_ltx2_i2av_distilled_fp8.mp4"
Expand All @@ -50,6 +51,7 @@
prompt=prompt,
image_path=image_path,
image_strength=image_strength,
image_frame_idx=image_frame_idx,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
5 changes: 4 additions & 1 deletion lightx2v/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def main():
)
parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task")
parser.add_argument("--image_strength", type=float, default=1.0, help="The strength of the image-to-audio-video (i2av) task")
parser.add_argument("--image_strength", type=str, default="1.0", help="i2av: single float, or comma-separated floats (one per image, or one value broadcast). Example: 1.0 or 1.0,0.85,0.9")
parser.add_argument(
"--image_frame_idx", type=str, default="", help="i2av: comma-separated pixel frame indices (one per image). Omit or empty to evenly space frames in [0, num_frames-1]. Example: 0,40,80"
)
# [Warning] For vace task, need refactor.
parser.add_argument(
"--src_ref_images",
Expand Down
32 changes: 21 additions & 11 deletions lightx2v/models/runners/default_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, mux_audio_from_video, save_to_video, wan_vae_to_comfy
from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, mux_audio_from_video, save_to_image, save_to_video, wan_vae_to_comfy
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)
Expand Down Expand Up @@ -454,16 +454,26 @@ def process_images_after_vae_decoder(self):
fps = self.config.get("fps", 16)

if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(f"🎬 Start to save video 🎬")

save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
if self.config.get("task") == "sr":
input_video_path = getattr(self.input_info, "video_path", "")
if input_video_path:
muxed_path = mux_audio_from_video(input_video_path, self.input_info.save_result_path)
if muxed_path:
logger.info(f"Audio muxed from input video: {input_video_path}")
logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
out_path = self.input_info.save_result_path
img_in = (getattr(self.input_info, "image_path", None) or "").strip()
vid_in = (getattr(self.input_info, "video_path", None) or "").strip()
sr_from_image_only = self.config.get("task") == "sr" and bool(img_in) and not bool(vid_in)

if sr_from_image_only:
logger.info("🖼 Start to save SR image (image_path input, no video_path) 🖼")
save_to_image(self.gen_video_final, out_path)
logger.info(f"✅ Image saved successfully to: {out_path} ✅")
else:
logger.info(f"🎬 Start to save video 🎬")

save_to_video(self.gen_video_final, out_path, fps=fps, method="ffmpeg")
if self.config.get("task") == "sr":
input_video_path = getattr(self.input_info, "video_path", "")
if input_video_path:
muxed_path = mux_audio_from_video(input_video_path, out_path)
if muxed_path:
logger.info(f"Audio muxed from input video: {input_video_path}")
logger.info(f"✅ Video saved successfully to: {out_path} ✅")
return {"video": None}

@ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"])
Expand Down
137 changes: 119 additions & 18 deletions lightx2v/models/runners/ltx2/ltx2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,49 @@
torch_device_module = getattr(torch, AI_DEVICE)


def _ltx2_parse_image_paths(image_path: str) -> list[str]:
return [p.strip() for p in image_path.split(",") if p.strip()]


def _ltx2_normalize_image_strengths(image_strength, n: int) -> list[float]:
if not isinstance(image_strength, list):
return [float(image_strength)] * n
if len(image_strength) == 1:
return [float(image_strength[0])] * n
if len(image_strength) != n:
raise ValueError(f"i2av image_strength: expected 1 or {n} values (scalar or list), got length {len(image_strength)}")
return [float(x) for x in image_strength]


def _ltx2_resolve_pixel_frame_indices(image_frame_idx, n: int, num_frames: int) -> list[int]:
if not image_frame_idx:
if n == 1:
return [0]
if num_frames <= 1:
return [0] * n
return [round(i * (num_frames - 1) / (n - 1)) for i in range(n)]
if len(image_frame_idx) != n:
raise ValueError(f"i2av image_frame_idx: expected {n} indices (one per image), got {len(image_frame_idx)}")
hi = num_frames - 1
return [max(0, min(hi, int(x))) for x in image_frame_idx]


def _ltx2_pixel_to_latent_frame_idx(pixel_frame_idx: int, temporal_scale: int) -> int:
if pixel_frame_idx == 0:
return 0
return (pixel_frame_idx - 1) // temporal_scale + 1


def _ltx2_resize_video_denoise_mask_for_stage2(mask: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
"""Resize stage-1 unpatchified video denoise mask to stage-2 latent spatial size."""
# mask shape: [1, F, H, W] -> [F, 1, H, W] for 2D interpolation
m = mask.to(dtype=torch.float32)
m = m.permute(1, 0, 2, 3)
m = torch.nn.functional.interpolate(m, size=(target_h, target_w), mode="nearest")
# back to [1, F, H, W]
return m.permute(1, 0, 2, 3).contiguous()


@RUNNER_REGISTER("ltx2")
class LTX2Runner(DefaultRunner):
def __init__(self, config):
Expand Down Expand Up @@ -146,14 +189,15 @@ def get_latent_shape_with_target_hw(self):
target_width = self.config["target_width"]
self.input_info.target_shape = [target_height, target_width]

target_video_length = self.input_info.target_video_length or self.config["target_video_length"]
video_latent_shape = (
self.config.get("num_channels_latents", 128),
(self.config["target_video_length"] - 1) // self.config["vae_scale_factors"][0] + 1,
(target_video_length - 1) // self.config["vae_scale_factors"][0] + 1,
int(target_height) // self.config["vae_scale_factors"][1],
int(target_width) // self.config["vae_scale_factors"][2],
)

duration = float(self.config["target_video_length"]) / float(self.config["fps"])
duration = float(target_video_length) / float(self.config["fps"])
latents_per_second = float(self.config["audio_sampling_rate"]) / float(self.config["audio_hop_length"]) / float(self.config["audio_scale_factor"])
audio_frames = round(duration * latents_per_second)

Expand All @@ -178,8 +222,26 @@ def _run_input_encoder_local_t2av(self):
"image_encoder_output": None,
}

def _normalize_i2av_input_fields(self) -> None:
    """Coerce CLI-style string fields on input_info into numeric form and validate counts.

    - image_strength: "1.0" or "1.0,0.8,..." -> float, list[float], or 1.0 when empty.
    - image_frame_idx: "0,40,120" -> list[int], or None when empty.
    Then checks that list lengths agree with the number of images in image_path.
    """
    info = self.input_info
    if isinstance(info.image_strength, str):
        values = [float(token.strip()) for token in info.image_strength.split(",") if token.strip()]
        if not values:
            info.image_strength = 1.0
        elif len(values) == 1:
            info.image_strength = values[0]
        else:
            info.image_strength = values
    if isinstance(info.image_frame_idx, str):
        indices = [int(token.strip()) for token in info.image_frame_idx.split(",") if token.strip()]
        info.image_frame_idx = indices if indices else None
    num_images = len(_ltx2_parse_image_paths(info.image_path or ""))
    if num_images == 0:
        # No images supplied: nothing to validate against.
        return
    strength = info.image_strength
    frame_idx = info.image_frame_idx
    if isinstance(strength, list) and len(strength) not in (1, num_images):
        raise ValueError(f"i2av image_strength: need 1 or {num_images} values, got {len(strength)}")
    if frame_idx is not None and len(frame_idx) != num_images:
        raise ValueError(f"i2av image_frame_idx: need {num_images} indices, got {len(frame_idx)}")

@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_i2av(self):
self._normalize_i2av_input_fields()
self.input_info.video_latent_shape, self.input_info.audio_latent_shape = self.get_latent_shape_with_target_hw()
text_encoder_output = self.run_text_encoder(self.input_info)
self.video_denoise_mask, self.initial_video_latent = self.run_vae_encoder()
Expand Down Expand Up @@ -231,14 +293,27 @@ def run_vae_encoder(self):
device=AI_DEVICE,
)

# Process each image conditioning
image_paths = self.input_info.image_path.split(",") # image_path1,image_path2,image_path3
for frame_idx, image_path in enumerate(image_paths):
if not isinstance(self.input_info.image_strength, list):
strength = self.input_info.image_strength
else:
strength = self.input_info.image_strength[frame_idx]
logger.info(f" 📷 Loading image: {image_path} for frame {frame_idx} with strength {strength}")
image_paths = _ltx2_parse_image_paths(self.input_info.image_path)
n = len(image_paths)
if n == 0:
logger.warning("i2av: image_path is empty, skipping image conditioning")
self._i2av_guiding_keyframe_meta = None
torch_device_module.empty_cache()
gc.collect()
return video_denoise_mask, initial_video_latent

num_frames = self.input_info.target_video_length or self.config.get("target_video_length", 1)
strengths = _ltx2_normalize_image_strengths(self.input_info.image_strength, n)
raw_frame_idx = getattr(self.input_info, "image_frame_idx", None)
pixel_frame_indices = _ltx2_resolve_pixel_frame_indices(raw_frame_idx, n, num_frames)
temporal_scale = int(self.config["vae_scale_factors"][0])

guiding_keyframe_meta: list[tuple[str, int, float]] = []

for i, image_path in enumerate(image_paths):
strength = strengths[i]
pixel_frame_idx = pixel_frame_indices[i]
logger.info(f" 📷 Loading image: {image_path} pixel_frame={pixel_frame_idx} strength={strength} ({i + 1}/{n})")

# Load and preprocess image
image = load_image_conditioning(
Expand All @@ -254,18 +329,15 @@ def run_vae_encoder(self):

encoded_latent = encoded_latent.squeeze(0)

# Verify frame index is valid
if frame_idx < 0 or frame_idx >= F:
logger.warning(f"⚠️ Frame index {frame_idx} out of range [0, {F - 1}], skipping")
# Pixel frame 0 → write into the latent time slot; other frames → guiding tokens appended in the scheduler.
if pixel_frame_idx != 0:
guiding_keyframe_meta.append((image_path, pixel_frame_idx, strength))
continue

# Get the latent frame index by converting pixel frame to latent frame
# For LTX2, temporal compression is 8x, so latent_frame_idx = (frame_idx - 1) // 8 + 1 for frame_idx > 0
# or 0 for frame_idx == 0
if frame_idx == 0:
latent_frame_idx = 0
else:
latent_frame_idx = (frame_idx - 1) // self.config["vae_scale_factors"][0] + 1
latent_frame_idx = _ltx2_pixel_to_latent_frame_idx(pixel_frame_idx, temporal_scale)

if latent_frame_idx >= F:
logger.warning(f"⚠️ Latent frame index {latent_frame_idx} out of range [0, {F - 1}], skipping")
Expand All @@ -281,6 +353,7 @@ def run_vae_encoder(self):
video_denoise_mask[:, latent_frame_idx, :, :] = 1.0 - strength

logger.info(f" ✓ Encoded image to latent frame {latent_frame_idx}")
self._i2av_guiding_keyframe_meta = guiding_keyframe_meta

torch_device_module.empty_cache()
gc.collect()
Expand All @@ -289,6 +362,26 @@ def run_vae_encoder(self):

return video_denoise_mask, initial_video_latent

def _build_i2av_video_guiding_latents(self):
    """Encode guiding keyframe images at the current target_shape for the scheduler.

    Returns a list of (latent, pixel_frame_idx, strength) tuples, one per
    keyframe recorded in `_i2av_guiding_keyframe_meta` by run_vae_encoder,
    or None when no guiding keyframes exist. Re-encoding here means the
    latents match whichever stage (1 or 2) resolution is currently active.
    """
    keyframe_meta = getattr(self, "_i2av_guiding_keyframe_meta", None)
    if not keyframe_meta:
        return None
    height, width = self.input_info.target_shape[0], self.input_info.target_shape[1]
    guiding_latents = []
    for image_path, pixel_idx, strength in keyframe_meta:
        conditioning = load_image_conditioning(
            image_path=image_path,
            height=height,
            width=width,
            dtype=GET_DTYPE(),
            device=AI_DEVICE,
        )
        # No gradients needed: VAE encode is pure inference here.
        with torch.no_grad():
            latent = self.video_vae.encode(conditioning).squeeze(0)
        guiding_latents.append((latent, pixel_idx, strength))
    return guiding_latents

@ProfilingContext4DebugL1(
"Run Text Encoder",
recorder_mode=GET_RECORDER_MODE(),
Expand Down Expand Up @@ -365,12 +458,16 @@ def run_upsampler(self, v_latent, a_latent):

self.input_info.target_shape = [self.input_info.target_shape[0] * 2, self.input_info.target_shape[1] * 2]
self.input_info.video_latent_shape, self.input_info.audio_latent_shape = self.get_latent_shape_with_target_hw()
_, _, stage2_h, stage2_w = self.input_info.video_latent_shape
stage2_video_denoise_mask = None
if hasattr(self, "video_denoise_mask") and self.video_denoise_mask is not None:
stage2_video_denoise_mask = _ltx2_resize_video_denoise_mask_for_stage2(self.video_denoise_mask, stage2_h, stage2_w)

# Prepare scheduler using the shared method
self._prepare_scheduler(
initial_video_latent=upsampled_v_latent, # Use upsampled video latent
initial_audio_latent=a_latent, # Keep audio from stage 1 (aligned with distilled.py:183)
video_denoise_mask=None, # Stage 2 fully denoises, no mask needed
video_denoise_mask=stage2_video_denoise_mask, # Keep keyframe constraints in stage 2
noise_scale=upsample_distilled_sigmas[0].item(), # Use first sigma as noise_scale (aligned with distilled.py:181)
)

Expand Down Expand Up @@ -425,6 +522,10 @@ def _prepare_scheduler(
if noise_scale is not None:
prepare_kwargs["noise_scale"] = noise_scale

vg = self._build_i2av_video_guiding_latents()
if vg:
prepare_kwargs["video_guiding_latents"] = vg

self.model.scheduler.prepare(**prepare_kwargs)

def init_run(self):
Expand Down
Loading
Loading