Skip to content

Commit 5937896

Browse files
authored
support ltx2.3 keyframes to videos (#990)
1 parent 73147d4 commit 5937896

File tree

11 files changed

+305
-40
lines changed

11 files changed

+305
-40
lines changed

examples/ltx2/ltxt_i2av.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,25 @@
2626
seed = 42
2727
image_path = "/path/to/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg"
2828
image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths
29+
# Pixel frame index per image (optional). If None, indices are evenly spaced in [0, num_frames-1] (see create_generator num_frames).
30+
# Example for 3 images and num_frames=121: omit image_frame_idx to get ~[0, 60, 120], or set explicitly:
31+
# image_frame_idx = [0, 40, 120]
32+
image_frame_idx = None
2933
prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance."
3034
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
3135
save_result_path = "/path/to/save_results/output.mp4"
3236

3337
# Note: image_strength can also be set in config_json
3438
# For scalar: image_strength = 1.0 (all images use same strength)
3539
# For list: image_strength = [1.0, 0.8] (must match number of images)
40+
# image_frame_idx: list of pixel-frame indices (one per image), or None for even spacing across the clip
3641

3742
pipe.generate(
3843
seed=seed,
3944
prompt=prompt,
4045
image_path=image_path,
4146
image_strength=image_strength,
47+
image_frame_idx=image_frame_idx,
4248
negative_prompt=negative_prompt,
4349
save_result_path=save_result_path,
4450
)

examples/ltx2/ltxt_i2av_distilled_fp8.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
seed = 42
3939
image_path = "/path/to/LightX2V/assets/inputs/imgs/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg"
4040
image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths
41+
image_frame_idx = None # Or e.g. [0, 60, 120] — pixel frame per image; None = evenly spaced in [0, num_frames-1]
4142
prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance."
4243
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
4344
save_result_path = "/path/to/LightX2V/save_results/output_lightx2v_ltx2_i2av_distilled_fp8.mp4"
@@ -50,6 +51,7 @@
5051
prompt=prompt,
5152
image_path=image_path,
5253
image_strength=image_strength,
54+
image_frame_idx=image_frame_idx,
5355
negative_prompt=negative_prompt,
5456
save_result_path=save_result_path,
5557
)

lightx2v/infer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ def main():
104104
)
105105
parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
106106
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task")
107-
parser.add_argument("--image_strength", type=float, default=1.0, help="The strength of the image-to-audio-video (i2av) task")
107+
parser.add_argument("--image_strength", type=str, default="1.0", help="i2av: single float, or comma-separated floats (one per image, or one value broadcast). Example: 1.0 or 1.0,0.85,0.9")
108+
parser.add_argument(
109+
"--image_frame_idx", type=str, default="", help="i2av: comma-separated pixel frame indices (one per image). Omit or empty to evenly space frames in [0, num_frames-1]. Example: 0,40,80"
110+
)
108111
# [Warning] For vace task, need refactor.
109112
parser.add_argument(
110113
"--src_ref_images",

lightx2v/models/runners/default_runner.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from lightx2v.utils.generate_task_id import generate_task_id
1616
from lightx2v.utils.global_paras import CALIB
1717
from lightx2v.utils.profiler import *
18-
from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, mux_audio_from_video, save_to_video, wan_vae_to_comfy
18+
from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, mux_audio_from_video, save_to_image, save_to_video, wan_vae_to_comfy
1919
from lightx2v_platform.base.global_var import AI_DEVICE
2020

2121
torch_device_module = getattr(torch, AI_DEVICE)
@@ -457,16 +457,26 @@ def process_images_after_vae_decoder(self):
457457
fps = self.config.get("fps", 16)
458458

459459
if not dist.is_initialized() or dist.get_rank() == 0:
460-
logger.info(f"🎬 Start to save video 🎬")
461-
462-
save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
463-
if self.config.get("task") == "sr":
464-
input_video_path = getattr(self.input_info, "video_path", "")
465-
if input_video_path:
466-
muxed_path = mux_audio_from_video(input_video_path, self.input_info.save_result_path)
467-
if muxed_path:
468-
logger.info(f"Audio muxed from input video: {input_video_path}")
469-
logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
460+
out_path = self.input_info.save_result_path
461+
img_in = (getattr(self.input_info, "image_path", None) or "").strip()
462+
vid_in = (getattr(self.input_info, "video_path", None) or "").strip()
463+
sr_from_image_only = self.config.get("task") == "sr" and bool(img_in) and not bool(vid_in)
464+
465+
if sr_from_image_only:
466+
logger.info("🖼 Start to save SR image (image_path input, no video_path) 🖼")
467+
save_to_image(self.gen_video_final, out_path)
468+
logger.info(f"✅ Image saved successfully to: {out_path} ✅")
469+
else:
470+
logger.info(f"🎬 Start to save video 🎬")
471+
472+
save_to_video(self.gen_video_final, out_path, fps=fps, method="ffmpeg")
473+
if self.config.get("task") == "sr":
474+
input_video_path = getattr(self.input_info, "video_path", "")
475+
if input_video_path:
476+
muxed_path = mux_audio_from_video(input_video_path, out_path)
477+
if muxed_path:
478+
logger.info(f"Audio muxed from input video: {input_video_path}")
479+
logger.info(f"✅ Video saved successfully to: {out_path} ✅")
470480
return {"video": None}
471481

472482
@ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"])

lightx2v/models/runners/ltx2/ltx2_runner.py

Lines changed: 119 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,49 @@
2121
torch_device_module = getattr(torch, AI_DEVICE)
2222

2323

24+
def _ltx2_parse_image_paths(image_path: str) -> list[str]:
25+
return [p.strip() for p in image_path.split(",") if p.strip()]
26+
27+
28+
def _ltx2_normalize_image_strengths(image_strength, n: int) -> list[float]:
29+
if not isinstance(image_strength, list):
30+
return [float(image_strength)] * n
31+
if len(image_strength) == 1:
32+
return [float(image_strength[0])] * n
33+
if len(image_strength) != n:
34+
raise ValueError(f"i2av image_strength: expected 1 or {n} values (scalar or list), got length {len(image_strength)}")
35+
return [float(x) for x in image_strength]
36+
37+
38+
def _ltx2_resolve_pixel_frame_indices(image_frame_idx, n: int, num_frames: int) -> list[int]:
39+
if not image_frame_idx:
40+
if n == 1:
41+
return [0]
42+
if num_frames <= 1:
43+
return [0] * n
44+
return [round(i * (num_frames - 1) / (n - 1)) for i in range(n)]
45+
if len(image_frame_idx) != n:
46+
raise ValueError(f"i2av image_frame_idx: expected {n} indices (one per image), got {len(image_frame_idx)}")
47+
hi = num_frames - 1
48+
return [max(0, min(hi, int(x))) for x in image_frame_idx]
49+
50+
51+
def _ltx2_pixel_to_latent_frame_idx(pixel_frame_idx: int, temporal_scale: int) -> int:
52+
if pixel_frame_idx == 0:
53+
return 0
54+
return (pixel_frame_idx - 1) // temporal_scale + 1
55+
56+
57+
def _ltx2_resize_video_denoise_mask_for_stage2(mask: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
58+
"""Resize stage-1 unpatchified video denoise mask to stage-2 latent spatial size."""
59+
# mask shape: [1, F, H, W] -> [F, 1, H, W] for 2D interpolation
60+
m = mask.to(dtype=torch.float32)
61+
m = m.permute(1, 0, 2, 3)
62+
m = torch.nn.functional.interpolate(m, size=(target_h, target_w), mode="nearest")
63+
# back to [1, F, H, W]
64+
return m.permute(1, 0, 2, 3).contiguous()
65+
66+
2467
@RUNNER_REGISTER("ltx2")
2568
class LTX2Runner(DefaultRunner):
2669
def __init__(self, config):
@@ -146,14 +189,15 @@ def get_latent_shape_with_target_hw(self):
146189
target_width = self.config["target_width"]
147190
self.input_info.target_shape = [target_height, target_width]
148191

192+
target_video_length = self.input_info.target_video_length or self.config["target_video_length"]
149193
video_latent_shape = (
150194
self.config.get("num_channels_latents", 128),
151-
(self.config["target_video_length"] - 1) // self.config["vae_scale_factors"][0] + 1,
195+
(target_video_length - 1) // self.config["vae_scale_factors"][0] + 1,
152196
int(target_height) // self.config["vae_scale_factors"][1],
153197
int(target_width) // self.config["vae_scale_factors"][2],
154198
)
155199

156-
duration = float(self.config["target_video_length"]) / float(self.config["fps"])
200+
duration = float(target_video_length) / float(self.config["fps"])
157201
latents_per_second = float(self.config["audio_sampling_rate"]) / float(self.config["audio_hop_length"]) / float(self.config["audio_scale_factor"])
158202
audio_frames = round(duration * latents_per_second)
159203

@@ -178,8 +222,26 @@ def _run_input_encoder_local_t2av(self):
178222
"image_encoder_output": None,
179223
}
180224

225+
def _normalize_i2av_input_fields(self) -> None:
    """Coerce CLI-string i2av fields on input_info into typed values and validate counts.

    image_strength: "1.0" or "1.0,0.8,..." -> float or list[float] (empty string -> 1.0).
    image_frame_idx: "0,40,80" -> list[int] (empty string -> None).
    Raises ValueError when a list length matches neither 1 nor the number of images.
    """
    info = self.input_info

    strength = info.image_strength
    if isinstance(strength, str):
        values = [float(token.strip()) for token in strength.split(",") if token.strip()]
        if not values:
            info.image_strength = 1.0
        elif len(values) == 1:
            info.image_strength = values[0]
        else:
            info.image_strength = values

    frame_idx = info.image_frame_idx
    if isinstance(frame_idx, str):
        indices = [int(token.strip()) for token in frame_idx.split(",") if token.strip()]
        info.image_frame_idx = indices if indices else None

    n = len(_ltx2_parse_image_paths(info.image_path or ""))
    if n == 0:
        return

    strength, frame_idx = info.image_strength, info.image_frame_idx
    if isinstance(strength, list) and len(strength) not in (1, n):
        raise ValueError(f"i2av image_strength: need 1 or {n} values, got {len(strength)}")
    if frame_idx is not None and len(frame_idx) != n:
        raise ValueError(f"i2av image_frame_idx: need {n} indices, got {len(frame_idx)}")
241+
181242
@ProfilingContext4DebugL2("Run Encoders")
182243
def _run_input_encoder_local_i2av(self):
244+
self._normalize_i2av_input_fields()
183245
self.input_info.video_latent_shape, self.input_info.audio_latent_shape = self.get_latent_shape_with_target_hw()
184246
text_encoder_output = self.run_text_encoder(self.input_info)
185247
self.video_denoise_mask, self.initial_video_latent = self.run_vae_encoder()
@@ -231,14 +293,27 @@ def run_vae_encoder(self):
231293
device=AI_DEVICE,
232294
)
233295

234-
# Process each image conditioning
235-
image_paths = self.input_info.image_path.split(",") # image_path1,image_path2,image_path3
236-
for frame_idx, image_path in enumerate(image_paths):
237-
if not isinstance(self.input_info.image_strength, list):
238-
strength = self.input_info.image_strength
239-
else:
240-
strength = self.input_info.image_strength[frame_idx]
241-
logger.info(f" 📷 Loading image: {image_path} for frame {frame_idx} with strength {strength}")
296+
image_paths = _ltx2_parse_image_paths(self.input_info.image_path)
297+
n = len(image_paths)
298+
if n == 0:
299+
logger.warning("i2av: image_path is empty, skipping image conditioning")
300+
self._i2av_guiding_keyframe_meta = None
301+
torch_device_module.empty_cache()
302+
gc.collect()
303+
return video_denoise_mask, initial_video_latent
304+
305+
num_frames = self.input_info.target_video_length or self.config.get("target_video_length", 1)
306+
strengths = _ltx2_normalize_image_strengths(self.input_info.image_strength, n)
307+
raw_frame_idx = getattr(self.input_info, "image_frame_idx", None)
308+
pixel_frame_indices = _ltx2_resolve_pixel_frame_indices(raw_frame_idx, n, num_frames)
309+
temporal_scale = int(self.config["vae_scale_factors"][0])
310+
311+
guiding_keyframe_meta: list[tuple[str, int, float]] = []
312+
313+
for i, image_path in enumerate(image_paths):
314+
strength = strengths[i]
315+
pixel_frame_idx = pixel_frame_indices[i]
316+
logger.info(f" 📷 Loading image: {image_path} pixel_frame={pixel_frame_idx} strength={strength} ({i + 1}/{n})")
242317

243318
# Load and preprocess image
244319
image = load_image_conditioning(
@@ -254,18 +329,15 @@ def run_vae_encoder(self):
254329

255330
encoded_latent = encoded_latent.squeeze(0)
256331

257-
# Verify frame index is valid
258-
if frame_idx < 0 or frame_idx >= F:
259-
logger.warning(f"⚠️ Frame index {frame_idx} out of range [0, {F - 1}], skipping")
332+
# Pixel frame 0 → write into the latent time slot; other frames → guiding tokens appended in the scheduler.
333+
if pixel_frame_idx != 0:
334+
guiding_keyframe_meta.append((image_path, pixel_frame_idx, strength))
260335
continue
261336

262337
# Get the latent frame index by converting pixel frame to latent frame
263338
# For LTX2, temporal compression is 8x, so latent_frame_idx = (frame_idx - 1) // 8 + 1 for frame_idx > 0
264339
# or 0 for frame_idx == 0
265-
if frame_idx == 0:
266-
latent_frame_idx = 0
267-
else:
268-
latent_frame_idx = (frame_idx - 1) // self.config["vae_scale_factors"][0] + 1
340+
latent_frame_idx = _ltx2_pixel_to_latent_frame_idx(pixel_frame_idx, temporal_scale)
269341

270342
if latent_frame_idx >= F:
271343
logger.warning(f"⚠️ Latent frame index {latent_frame_idx} out of range [0, {F - 1}], skipping")
@@ -281,6 +353,7 @@ def run_vae_encoder(self):
281353
video_denoise_mask[:, latent_frame_idx, :, :] = 1.0 - strength
282354

283355
logger.info(f" ✓ Encoded image to latent frame {latent_frame_idx}")
356+
self._i2av_guiding_keyframe_meta = guiding_keyframe_meta
284357

285358
torch_device_module.empty_cache()
286359
gc.collect()
@@ -289,6 +362,26 @@ def run_vae_encoder(self):
289362

290363
return video_denoise_mask, initial_video_latent
291364

365+
def _build_i2av_video_guiding_latents(self):
    """Encode guiding keyframe images at current target_shape for scheduler.append (stage 1 / 2).

    Returns a list of (latent, pixel_frame_idx, strength) tuples, or None when no
    guiding keyframes were recorded (meta is absent or empty). Re-encoding at the
    current target_shape lets the same keyframes serve both pipeline stages.
    """
    keyframes = getattr(self, "_i2av_guiding_keyframe_meta", None)
    if not keyframes:
        return None
    target_height, target_width = self.input_info.target_shape[0], self.input_info.target_shape[1]
    guiding = []
    for image_path, pixel_frame_idx, strength in keyframes:
        conditioning = load_image_conditioning(
            image_path=image_path,
            height=target_height,
            width=target_width,
            dtype=GET_DTYPE(),
            device=AI_DEVICE,
        )
        # Inference only: keep VAE encode out of autograd.
        with torch.no_grad():
            latent = self.video_vae.encode(conditioning).squeeze(0)
        guiding.append((latent, pixel_frame_idx, strength))
    return guiding
384+
292385
@ProfilingContext4DebugL1(
293386
"Run Text Encoder",
294387
recorder_mode=GET_RECORDER_MODE(),
@@ -365,12 +458,16 @@ def run_upsampler(self, v_latent, a_latent):
365458

366459
self.input_info.target_shape = [self.input_info.target_shape[0] * 2, self.input_info.target_shape[1] * 2]
367460
self.input_info.video_latent_shape, self.input_info.audio_latent_shape = self.get_latent_shape_with_target_hw()
461+
_, _, stage2_h, stage2_w = self.input_info.video_latent_shape
462+
stage2_video_denoise_mask = None
463+
if hasattr(self, "video_denoise_mask") and self.video_denoise_mask is not None:
464+
stage2_video_denoise_mask = _ltx2_resize_video_denoise_mask_for_stage2(self.video_denoise_mask, stage2_h, stage2_w)
368465

369466
# Prepare scheduler using the shared method
370467
self._prepare_scheduler(
371468
initial_video_latent=upsampled_v_latent, # Use upsampled video latent
372469
initial_audio_latent=a_latent, # Keep audio from stage 1 (aligned with distilled.py:183)
373-
video_denoise_mask=None, # Stage 2 fully denoises, no mask needed
470+
video_denoise_mask=stage2_video_denoise_mask, # Keep keyframe constraints in stage 2
374471
noise_scale=upsample_distilled_sigmas[0].item(), # Use first sigma as noise_scale (aligned with distilled.py:181)
375472
)
376473

@@ -425,6 +522,10 @@ def _prepare_scheduler(
425522
if noise_scale is not None:
426523
prepare_kwargs["noise_scale"] = noise_scale
427524

525+
vg = self._build_i2av_video_guiding_latents()
526+
if vg:
527+
prepare_kwargs["video_guiding_latents"] = vg
528+
428529
self.model.scheduler.prepare(**prepare_kwargs)
429530

430531
def init_run(self):

0 commit comments

Comments
 (0)