Skip to content

Commit 26cd283

Browse files
committed
support ltx2 a2v
1 parent 8b0c7fc commit 26cd283

3 files changed

Lines changed: 83 additions & 38 deletions

File tree

diffsynth/pipelines/ltx2_audio_video.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
6767
LTX2AudioVideoUnit_SwitchStage2(),
6868
LTX2AudioVideoUnit_NoiseInitializer(),
6969
LTX2AudioVideoUnit_LatentsUpsampler(),
70-
LTX2AudioVideoUnit_SetScheduleStage2(),
7170
LTX2AudioVideoUnit_InputImagesEmbedder(),
71+
LTX2AudioVideoUnit_InputAudioEmbedder(),
72+
LTX2AudioVideoUnit_SetScheduleStage2(),
7273
]
7374
self.model_fn = model_fn_ltx2
7475

@@ -155,8 +156,9 @@ def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scal
155156
**models, timestep=timestep, progress_id=progress_id
156157
)
157158
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
158-
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
159-
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
159+
inpaint_mask=inputs_shared.get("video_denoise_mask", None), input_latents=inputs_shared.get("video_input_latents", None), **inputs_shared)
160+
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio,
161+
inpaint_mask=inputs_shared.get("audio_denoise_mask", None), input_latents=inputs_shared.get("audio_input_latents", None), **inputs_shared)
160162
return inputs_shared, inputs_posi, inputs_nega
161163

162164
@torch.no_grad()
@@ -173,6 +175,9 @@ def __call__(
173175
# In-Context Video Control
174176
in_context_videos: Optional[list[list[Image.Image]]] = None,
175177
in_context_downsample_factor: Optional[int] = 2,
178+
# Audio-to-video
179+
input_audio: Optional[torch.Tensor] = None,
180+
audio_sample_rate: Optional[int] = 48000,
176181
# Randomness
177182
seed: Optional[int] = None,
178183
rand_device: Optional[str] = "cpu",
@@ -210,6 +215,7 @@ def __call__(
210215
}
211216
inputs_shared = {
212217
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
218+
"input_audio": (input_audio, audio_sample_rate) if input_audio is not None else None,
213219
"in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor,
214220
"seed": seed, "rand_device": rand_device,
215221
"height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
@@ -361,7 +367,8 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled,
361367
input_video = pipe.preprocess_video(input_video)
362368
input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
363369
if pipe.scheduler.training:
364-
return {"video_latents": input_latents, "input_latents": input_latents}
370+
# The "input_latents" key is used by training to add noise; it deliberately has no "video" prefix so the loss function's keyword stays consistent.
371+
return {"video_latents": video_noise, "input_latents": input_latents}
365372
else:
366373
raise NotImplementedError("Video-to-video not implemented yet.")
367374

@@ -370,7 +377,7 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
370377
def __init__(self):
    # Register the parameters this unit consumes, the entries it produces,
    # and the model that has to be onloaded before `process` runs.
    consumed = ("input_audio", "audio_noise")
    produced = (
        "audio_latents",
        "audio_input_latents",
        "audio_noise",
        "audio_positions",
        "audio_latent_shape",
        "audio_denoise_mask",
    )
    required_models = ("audio_vae_encoder",)
    super().__init__(
        input_params=consumed,
        output_params=produced,
        onload_model_names=required_models,
    )
376383

@@ -380,21 +387,37 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
380387
else:
381388
input_audio, sample_rate = input_audio
382389
pipe.load_models_to_device(self.onload_model_names)
383-
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
390+
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device)
384391
audio_input_latents = pipe.audio_vae_encoder(input_audio)
385392
audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
386393
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
387394
if pipe.scheduler.training:
388-
return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
395+
return {
396+
"audio_latents": audio_input_latents,
397+
"audio_input_latents": audio_input_latents,
398+
"audio_positions": audio_positions,
399+
"audio_latent_shape": audio_latent_shape,
400+
}
389401
else:
390-
raise NotImplementedError("Audio-to-video not supported.")
402+
b, c, t, f = audio_input_latents.shape
403+
audio_denoise_mask = torch.zeros((b, 1, t, 1), device=audio_input_latents.device, dtype=audio_input_latents.dtype)
404+
audio_noise = torch.rand_like(audio_input_latents)
405+
audio_latents = pipe.scheduler.add_noise(audio_input_latents, audio_noise, pipe.scheduler.timesteps[0])
406+
return {
407+
"audio_latents": audio_latents,
408+
"audio_input_latents": audio_input_latents,
409+
"audio_noise": audio_noise,
410+
"audio_positions": audio_positions,
411+
"audio_latent_shape": audio_latent_shape,
412+
"audio_denoise_mask": audio_denoise_mask,
413+
}
391414

392415

393416
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
394417
def __init__(self):
    # Declare the parameters this unit consumes, the conditioning entries it
    # produces, and the model that has to be onloaded before `process` runs.
    super().__init__(
        input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
        output_params=("video_denoise_mask", "video_input_latents", "ref_frames_latents", "ref_frames_positions"),
        # The trailing comma matters: without it the parentheses are just
        # grouping and the value is a plain string instead of a 1-tuple of
        # model names (the audio embedder unit already uses the tuple form).
        onload_model_names=("video_vae_encoder",)
    )
400423

@@ -406,9 +429,9 @@ def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in
406429
latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
407430
return latents
408431

409-
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
432+
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, video_denoise_mask=None):
410433
b, _, f, h, w = latents.shape
411-
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
434+
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if video_denoise_mask is None else video_denoise_mask
412435
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
413436
for idx, input_latent in zip(input_indexes, input_latents):
414437
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
@@ -424,13 +447,13 @@ def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, hei
424447
if len(input_images_indexes) != len(set(input_images_indexes)):
425448
raise ValueError("Input images must have unique indexes.")
426449
pipe.load_models_to_device(self.onload_model_names)
427-
frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []}
450+
frame_conditions = {"video_input_latents": None, "video_denoise_mask": None, "ref_frames_latents": [], "ref_frames_positions": []}
428451
for img, index in zip(input_images, input_images_indexes):
429452
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
430453
# first_frame by replacing latents
431454
if index == 0:
432-
input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
433-
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
455+
video_input_latents, video_denoise_mask = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
456+
frame_conditions.update({"video_input_latents": video_input_latents, "video_denoise_mask": video_denoise_mask})
434457
# other frames by adding reference latents
435458
else:
436459
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)
@@ -560,14 +583,17 @@ def model_fn_ltx2(
560583
audio_patchifier=None,
561584
timestep=None,
562585
# First Frame Conditioning
563-
input_latents_video=None,
564-
denoise_mask_video=None,
586+
video_input_latents=None,
587+
video_denoise_mask=None,
565588
# Other Frames Conditioning
566589
ref_frames_latents=None,
567590
ref_frames_positions=None,
568591
# In-Context Conditioning
569592
in_context_video_latents=None,
570593
in_context_video_positions=None,
594+
# Audio Inputs
595+
audio_input_latents=None,
596+
audio_denoise_mask=None,
571597
# Gradient Checkpointing
572598
use_gradient_checkpointing=False,
573599
use_gradient_checkpointing_offload=False,
@@ -581,12 +607,12 @@ def model_fn_ltx2(
581607
seq_len_video = video_latents.shape[1]
582608
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
583609
# First-frame conditioning by replacing the video latents
584-
if input_latents_video is not None:
585-
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
586-
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
587-
video_timesteps = denoise_mask_video * video_timesteps
588-
589-
# Conditioning by replacing the video latents
610+
if video_input_latents is not None:
611+
video_denoise_mask = video_patchifier.patchify(video_denoise_mask)
612+
video_latents = video_latents * video_denoise_mask + video_patchifier.patchify(video_input_latents) * (1.0 - video_denoise_mask)
613+
video_timesteps = video_denoise_mask * video_timesteps
614+
615+
# Reference conditioning by appending the reference video or frame latents
590616
total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
591617
total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
592618
total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []
@@ -605,6 +631,10 @@ def model_fn_ltx2(
605631
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
606632
else:
607633
audio_timesteps = None
634+
if audio_input_latents is not None:
635+
audio_denoise_mask = audio_patchifier.patchify(audio_denoise_mask)
636+
audio_latents = audio_latents * audio_denoise_mask + audio_patchifier.patchify(audio_input_latents) * (1.0 - audio_denoise_mask)
637+
audio_timesteps = audio_denoise_mask * audio_timesteps
608638

609639
vx, ax = dit(
610640
video_latents=video_latents,

diffsynth/utils/data/media_io_ltx2.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
21
from fractions import Fraction
32
import torch
3+
import torchaudio
44
import av
55
from tqdm import tqdm
66
from PIL import Image
77
import numpy as np
88
from io import BytesIO
9-
from collections.abc import Generator, Iterator
10-
import torchaudio
119

1210

1311
def _resample_audio(
@@ -70,9 +68,9 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate:
7068
audio_stream = container.add_stream("aac")
7169
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
7270
if supported_sample_rates:
73-
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
74-
if best_rate != audio_sample_rate:
75-
print(f"Using closest supported audio sample rate: {best_rate}")
71+
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
72+
if best_rate != audio_sample_rate:
73+
print(f"Using closest supported audio sample rate: {best_rate}")
7674
else:
7775
best_rate = audio_sample_rate
7876
audio_stream.codec_context.sample_rate = best_rate
@@ -117,7 +115,7 @@ def write_video_audio_ltx2(
117115
stream.width = width
118116
stream.height = height
119117
stream.pix_fmt = "yuv420p"
120-
118+
121119
if audio is not None:
122120
if audio_sample_rate is None:
123121
raise ValueError("audio_sample_rate is required when audio is provided")
@@ -138,13 +136,24 @@ def write_video_audio_ltx2(
138136
container.close()
139137

140138

141-
def read_audio_with_torchaudio(path: str, start_time: float = 0, duration: float | None = None) -> torch.Tensor:
139+
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
    """Return `waveform` at `target_rate`, resampling only when the rates differ.

    When `source_rate` equals `target_rate` the input tensor itself is
    returned. Otherwise the result of `torchaudio.functional.resample` is
    cast back to the input's dtype.
    """
    if source_rate != target_rate:
        original_dtype = waveform.dtype
        converted = torchaudio.functional.resample(waveform, source_rate, target_rate)
        waveform = converted.to(dtype=original_dtype)
    return waveform
145+
146+
147+
def read_audio_with_torchaudio(path: str, start_time: float = 0, duration: float | None = None, resample: bool = False, resample_rate: int = 48000) -> tuple[torch.Tensor, int]:
    """Load an audio file and return a time-cropped waveform with its sample rate.

    Args:
        path: Audio file handed to ``torchaudio.load``.
        start_time: Offset in seconds at which the returned clip starts.
        duration: Clip length in seconds. ``None`` keeps everything from
            ``start_time`` to the end of the audio.
        resample: When true, resample to ``resample_rate`` before cropping.
        resample_rate: Target sample rate used when ``resample`` is true.

    Returns:
        ``(waveform, sample_rate)``: the cropped waveform (loaded with
        ``channels_first=True``) and the sample rate of the returned samples.

    Raises:
        ValueError: If ``start_time`` lies beyond the end of the audio.
    """
    waveform, sample_rate = torchaudio.load(path, channels_first=True)
    if resample:
        waveform = resample_waveform(waveform, sample_rate, resample_rate)
        # From here on all frame arithmetic uses the resampled rate.
        sample_rate = resample_rate
    start_frame = int(start_time * sample_rate)
    if start_frame > waveform.shape[-1]:
        raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
    # Use None (not -1) as the open-ended slice stop: a stop of -1 would
    # silently drop the final sample when no duration is requested.
    end_frame = None if duration is None else int(duration * sample_rate + start_frame)
    return waveform[..., start_frame:end_frame], sample_rate
148157

149158

150159
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:

examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import torch
22
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
33
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
4-
5-
audio = read_audio_with_torchaudio("data/example_video_dataset/ltx2/sing.MP3")
4+
from modelscope import dataset_snapshot_download
65

76
vram_config = {
87
"offload_dtype": torch.bfloat16,
@@ -25,7 +24,9 @@
2524
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
2625
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
2726
)
28-
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
27+
28+
dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset")
29+
prompt = "A beautiful woman with a flower crown is singing happily under a blooming cherry tree."
2930
negative_prompt = (
3031
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
3132
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
@@ -39,21 +40,26 @@
3940
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
4041
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
4142
)
42-
height, width, num_frames = 512 * 2, 768 * 2, 121
43+
height, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24
44+
duration = num_frames / frame_rate
45+
audio, audio_sample_rate = read_audio_with_torchaudio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
4346
video, audio = pipe(
4447
prompt=prompt,
4548
negative_prompt=negative_prompt,
49+
input_audio=audio,
50+
audio_sample_rate=audio_sample_rate,
4651
seed=43,
4752
height=height,
4853
width=width,
4954
num_frames=num_frames,
55+
frame_rate=frame_rate,
5056
tiled=True,
5157
use_two_stage_pipeline=True,
5258
)
5359
write_video_audio_ltx2(
5460
video=video,
5561
audio=audio,
56-
output_path='ltx2.3_twostage.mp4',
57-
fps=24,
62+
output_path='ltx2.3_twostage_a2v.mp4',
63+
fps=frame_rate,
5864
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
5965
)

0 commit comments

Comments
 (0)