
Commit c6f72ad

add ltx2 vae in sana-video; (#13229)
* add ltx2 vae in sana-video;
* add ltx vae in conversion script;
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* condition `vae_scale_factor_xxx` related settings on VAE types;
* make the mean/std depend on vae class;

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 11a3284 commit c6f72ad

File tree

3 files changed: +115 −42 lines changed

scripts/convert_sana_video_to_diffusers.py

Lines changed: 41 additions & 8 deletions
@@ -12,6 +12,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from diffusers import (
+    AutoencoderKLLTX2Video,
     AutoencoderKLWan,
     DPMSolverMultistepScheduler,
     FlowMatchEulerDiscreteScheduler,
@@ -24,7 +25,10 @@
 
 CTX = init_empty_weights if is_accelerate_available else nullcontext
 
-ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
+ckpt_ids = [
+    "Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
+    "Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
+]
 # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
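The added 720p checkpoint carries an `_LTXVAE` suffix, which presumably marks the variant trained against the LTX-2 latent space, i.e. the one meant to be converted with the new `--vae_type ltx2` flag introduced further down.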

@@ -92,12 +96,22 @@ def main(args):
     if args.video_size == 480:
         sample_size = 30  # Wan-VAE: 8xp2 downsample factor
         patch_size = (1, 2, 2)
+        in_channels = 16
+        out_channels = 16
     elif args.video_size == 720:
-        sample_size = 22  # Wan-VAE: 32xp1 downsample factor
+        sample_size = 22  # DC-AE-V: 32xp1 downsample factor
         patch_size = (1, 1, 1)
+        in_channels = 32
+        out_channels = 32
     else:
         raise ValueError(f"Video size {args.video_size} is not supported.")
 
+    if args.vae_type == "ltx2":
+        sample_size = 22
+        patch_size = (1, 1, 1)
+        in_channels = 128
+        out_channels = 128
+
     for depth in range(layer_num):
         # Transformer blocks.
         converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
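Note that `--vae_type ltx2` overrides the `--video_size` presets unconditionally: whatever the resolution, the transformer is built for the LTX-2 latent space (128 channels, 1x1x1 patchify). Restated as a single helper for readability (illustrative only; the script itself keeps the if/elif plus the override above, and these names are not script API):

```python
def transformer_preset(video_size: int, vae_type: str) -> dict:
    """Mirror of the branching above; helper and key names are illustrative."""
    if vae_type == "ltx2":
        return dict(sample_size=22, patch_size=(1, 1, 1), channels=128)  # LTX-2 VAE
    if video_size == 480:
        return dict(sample_size=30, patch_size=(1, 2, 2), channels=16)   # Wan-VAE, 8x
    if video_size == 720:
        return dict(sample_size=22, patch_size=(1, 1, 1), channels=32)   # DC-AE-V, 32x
    raise ValueError(f"Video size {video_size} is not supported.")
```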
@@ -182,8 +196,8 @@ def main(args):
     # Transformer
     with CTX():
         transformer_kwargs = {
-            "in_channels": 16,
-            "out_channels": 16,
+            "in_channels": in_channels,
+            "out_channels": out_channels,
             "num_attention_heads": 20,
             "attention_head_dim": 112,
             "num_layers": 20,
@@ -235,9 +249,12 @@ def main(args):
     else:
         print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
         # VAE
-        vae = AutoencoderKLWan.from_pretrained(
-            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
-        )
+        if args.vae_type == "ltx2":
+            vae_path = args.vae_path or "Lightricks/LTX-2"
+            vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
+        else:
+            vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+            vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
 
         # Text Encoder
         text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
@@ -314,7 +331,23 @@
         choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
         help="Scheduler type to use.",
     )
-    parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
+    parser.add_argument(
+        "--vae_type",
+        default="wan",
+        type=str,
+        choices=["wan", "ltx2"],
+        help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
+    )
+    parser.add_argument(
+        "--vae_path",
+        default=None,
+        type=str,
+        required=False,
+        help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
+    )
+    parser.add_argument(
+        "--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
     parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
     parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

src/diffusers/pipelines/sana_video/pipeline_sana_video.py

Lines changed: 31 additions & 13 deletions
@@ -24,7 +24,7 @@
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import SanaLoraLoaderMixin
-from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             The tokenizer used to tokenize the prompt.
         text_encoder ([`Gemma2PreTrainedModel`]):
             Text encoder model to encode the input prompts.
-        vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+        vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
         transformer ([`SanaVideoTransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
@@ -213,7 +213,7 @@ def __init__(
         self,
         tokenizer: GemmaTokenizer | GemmaTokenizerFast,
         text_encoder: Gemma2PreTrainedModel,
-        vae: AutoencoderDC | AutoencoderKLWan,
+        vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
         transformer: SanaVideoTransformer3DModel,
         scheduler: DPMSolverMultistepScheduler,
     ):
@@ -223,8 +223,19 @@ def __init__(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+        if getattr(self, "vae", None):
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
+                self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
+            elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
+                self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
+                self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
+            else:
+                self.vae_scale_factor_temporal = 4
+                self.vae_scale_factor_spatial = 8
+        else:
+            self.vae_scale_factor_temporal = 4
+            self.vae_scale_factor_spatial = 8
 
         self.vae_scale_factor = self.vae_scale_factor_spatial
 
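The branching exists because the two VAE families expose their compression ratios under different config names: LTX-2 uses `temporal_compression_ratio` / `spatial_compression_ratio`, while the Wan/DC configs use `scale_factor_temporal` / `scale_factor_spatial`. These factors drive all downstream pixel-to-latent size math; a minimal sketch of the usual convention for causal video VAEs (illustrative only, not the pipeline's actual code):

```python
# Illustrative latent-shape math for a causal video VAE, where the first
# frame is not temporally compressed. Helper and values are examples only.
def latent_shape(batch, z_dim, num_frames, height, width, temporal, spatial):
    num_latent_frames = (num_frames - 1) // temporal + 1
    return (batch, z_dim, num_latent_frames, height // spatial, width // spatial)

# A Wan-style VAE (4x temporal, 8x spatial, 16 latent channels):
print(latent_shape(1, 16, 81, 480, 832, temporal=4, spatial=8))
# -> (1, 16, 21, 60, 104)
```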

@@ -985,14 +996,21 @@ def __call__(
                 if is_torch_version(">=", "2.5.0")
                 else torch_accelerator_module.OutOfMemoryError
             )
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, self.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-                latents.device, latents.dtype
-            )
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                latents_mean = self.vae.latents_mean
+                latents_std = self.vae.latents_std
+                z_dim = self.vae.config.latent_channels
+            elif isinstance(self.vae, AutoencoderKLWan):
+                latents_mean = torch.tensor(self.vae.config.latents_mean)
+                latents_std = torch.tensor(self.vae.config.latents_std)
+                z_dim = self.vae.config.z_dim
+            else:
+                latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                z_dim = latents.shape[1]
+
+            latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+            latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
             latents = latents / latents_std + latents_mean
             try:
                 video = self.vae.decode(latents, return_dict=False)[0]
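The inverted-std convention here is easy to misread: `latents_std` holds the reciprocal of the per-channel std, so `latents / latents_std + latents_mean` is really `latents * std + mean`, i.e. plain denormalization before decoding. A toy check (illustrative only; the statistics are made up):

```python
import torch

# Hypothetical per-channel statistics for a 16-channel latent space.
mean = torch.randn(16)
std = torch.rand(16) + 0.5
x = torch.randn(2, 16, 3, 8, 8)  # normalized latents, (B, C, F, H, W)

inv_std = 1.0 / std.view(1, 16, 1, 1, 1)  # what the pipeline stores in latents_std
denorm = x / inv_std + mean.view(1, 16, 1, 1, 1)
assert torch.allclose(denorm, x * std.view(1, 16, 1, 1, 1) + mean.view(1, 16, 1, 1, 1))
```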

src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Lines changed: 43 additions & 21 deletions
@@ -26,7 +26,7 @@
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...loaders import SanaLoraLoaderMixin
-from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             The tokenizer used to tokenize the prompt.
         text_encoder ([`Gemma2PreTrainedModel`]):
             Text encoder model to encode the input prompts.
-        vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+        vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
         transformer ([`SanaVideoTransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ def __init__(
         self,
         tokenizer: GemmaTokenizer | GemmaTokenizerFast,
         text_encoder: Gemma2PreTrainedModel,
-        vae: AutoencoderDC | AutoencoderKLWan,
+        vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
         transformer: SanaVideoTransformer3DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
     ):
@@ -213,8 +213,19 @@ def __init__(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+        if getattr(self, "vae", None):
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
+                self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
+            elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
+                self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
+                self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
+            else:
+                self.vae_scale_factor_temporal = 4
+                self.vae_scale_factor_spatial = 8
+        else:
+            self.vae_scale_factor_temporal = 4
+            self.vae_scale_factor_spatial = 8
 
         self.vae_scale_factor = self.vae_scale_factor_spatial
 

@@ -687,14 +698,18 @@ def prepare_latents(
             image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
             image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
 
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, -1, 1, 1, 1)
-                .to(image_latents.device, image_latents.dtype)
-            )
-            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
-                image_latents.device, image_latents.dtype
-            )
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                _latents_mean = self.vae.latents_mean
+                _latents_std = self.vae.latents_std
+            elif isinstance(self.vae, AutoencoderKLWan):
+                _latents_mean = torch.tensor(self.vae.config.latents_mean)
+                _latents_std = torch.tensor(self.vae.config.latents_std)
+            else:
+                _latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
+                _latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
+
+            latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
+            latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
             image_latents = (image_latents - latents_mean) * latents_std
 
             latents[:, :, 0:1] = image_latents.to(dtype)
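This is the exact inverse of the decode-path transform: with `latents_std` again holding the reciprocal std, `(image_latents - latents_mean) * latents_std` computes `(x - mean) / std`, so the conditioning frame enters the denoising loop in the same normalized latent space the transformer operates in.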
@@ -1034,14 +1049,21 @@ def __call__(
                 if is_torch_version(">=", "2.5.0")
                 else torch_accelerator_module.OutOfMemoryError
             )
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, self.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-                latents.device, latents.dtype
-            )
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                latents_mean = self.vae.latents_mean
+                latents_std = self.vae.latents_std
+                z_dim = self.vae.config.latent_channels
+            elif isinstance(self.vae, AutoencoderKLWan):
+                latents_mean = torch.tensor(self.vae.config.latents_mean)
+                latents_std = torch.tensor(self.vae.config.latents_std)
+                z_dim = self.vae.config.z_dim
+            else:
+                latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                z_dim = latents.shape[1]
+
+            latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+            latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
             latents = latents / latents_std + latents_mean
             try:
                 video = self.vae.decode(latents, return_dict=False)[0]
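For reference, a minimal usage sketch for a converted full pipeline. The local path is hypothetical, and the call parameters follow common diffusers video-pipeline conventions rather than a confirmed signature:

```python
import torch
from diffusers import SanaVideoPipeline  # defined in pipeline_sana_video.py above
from diffusers.utils import export_to_video

# Hypothetical output directory produced by the conversion script with --save_full_pipeline.
pipe = SanaVideoPipeline.from_pretrained("./sana-video-720p-ltx2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Parameter names follow the usual diffusers conventions and may differ here.
video = pipe(prompt="a red panda climbing a bamboo stalk", num_frames=81).frames[0]
export_to_video(video, "sana_video.mp4", fps=16)
```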
