Skip to content

Commit c2aadca

Browse files
committed
Update
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 0d93e1a commit c2aadca

4 files changed

Lines changed: 243 additions & 57 deletions

File tree

examples/diffusers/quantization/quantize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,8 @@ def restore_checkpoint(self, backbone: nn.Module) -> None:
829829
mto.restore(backbone, str(self.config.restore_from))
830830
self.logger.info("Model restored successfully")
831831

832-
def export_hf_ckpt(self, pipe: DiffusionPipeline) -> None:
832+
# TODO: should not do the any data type
833+
def export_hf_ckpt(self, pipe: Any) -> None:
833834
"""
834835
Export quantized model to HuggingFace checkpoint format.
835836

modelopt/torch/export/diffusers_utils.py

Lines changed: 158 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,26 @@
1616
"""Code that export quantized Hugging Face models for deployment."""
1717

1818
import warnings
19+
from collections.abc import Callable
1920
from contextlib import contextmanager
2021
from importlib import import_module
2122
from typing import Any
2223

2324
import torch
2425
import torch.nn as nn
25-
from diffusers import DiffusionPipeline
2626

2727
from .layer_utils import is_quantlinear
2828

29+
DiffusionPipeline: type[Any] | None
30+
try: # diffusers is optional for LTX-2 export paths
31+
from diffusers import DiffusionPipeline as _DiffusionPipeline
32+
33+
DiffusionPipeline = _DiffusionPipeline
34+
_HAS_DIFFUSERS = True
35+
except Exception: # pragma: no cover
36+
DiffusionPipeline = None
37+
_HAS_DIFFUSERS = False
38+
2939

3040
def generate_diffusion_dummy_inputs(
3141
model: nn.Module, device: torch.device, dtype: torch.dtype
@@ -288,6 +298,126 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None:
288298
return None
289299

290300

301+
def generate_diffusion_dummy_forward_fn(model: nn.Module) -> Callable[[], None]:
302+
"""Create a dummy forward function for diffusion(-like) models.
303+
304+
- For diffusers components, this uses `generate_diffusion_dummy_inputs()` and calls `model(**kwargs)`.
305+
- For LTX-2 stage-1 transformer (X0Model), the forward signature is
306+
`model(video: Modality|None, audio: Modality|None, perturbations: BatchedPerturbationConfig)`,
307+
so we build tiny `ltx_core` dataclasses and call the model directly.
308+
"""
309+
# Duck-typed LTX-2 stage-1 transformer wrapper
310+
velocity_model = getattr(model, "velocity_model", None)
311+
if velocity_model is not None:
312+
313+
def _ltx2_dummy_forward() -> None:
314+
try:
315+
from ltx_core.guidance.perturbations import BatchedPerturbationConfig
316+
from ltx_core.model.transformer.modality import Modality
317+
except Exception as e: # pragma: no cover
318+
raise RuntimeError(
319+
"LTX-2 export requires `ltx_core` to be installed (Modality, BatchedPerturbationConfig)."
320+
) from e
321+
322+
# Small shapes for speed/memory
323+
batch_size = 1
324+
v_seq_len = 8
325+
a_seq_len = 8
326+
ctx_len = 4
327+
328+
device = next(model.parameters()).device
329+
default_dtype = next(model.parameters()).dtype
330+
331+
def _param_dtype(module: Any, fallback: torch.dtype) -> torch.dtype:
332+
w = getattr(getattr(module, "weight", None), "dtype", None)
333+
return w if isinstance(w, torch.dtype) else fallback
334+
335+
def _positions(bounds_dims: int, seq_len: int) -> torch.Tensor:
336+
# [B, dims, seq_len, 2] bounds (start/end)
337+
pos = torch.zeros(
338+
(batch_size, bounds_dims, seq_len, 2), device=device, dtype=torch.float32
339+
)
340+
pos[..., 1] = 1.0
341+
return pos
342+
343+
has_video = hasattr(velocity_model, "patchify_proj") and hasattr(
344+
velocity_model, "caption_projection"
345+
)
346+
has_audio = hasattr(velocity_model, "audio_patchify_proj") and hasattr(
347+
velocity_model, "audio_caption_projection"
348+
)
349+
if not has_video and not has_audio:
350+
raise ValueError(
351+
"Unsupported LTX-2 velocity model: missing both video and audio preprocessors."
352+
)
353+
354+
video = None
355+
if has_video:
356+
v_in = int(velocity_model.patchify_proj.in_features)
357+
v_caption_in = int(velocity_model.caption_projection.linear_1.in_features)
358+
v_latent_dtype = _param_dtype(velocity_model.patchify_proj, default_dtype)
359+
v_ctx_dtype = _param_dtype(
360+
velocity_model.caption_projection.linear_1, default_dtype
361+
)
362+
video = Modality(
363+
enabled=True,
364+
latent=torch.randn(
365+
batch_size, v_seq_len, v_in, device=device, dtype=v_latent_dtype
366+
),
367+
# LTX `X0Model` uses `timesteps` as the sigma tensor in `to_denoised(sample, velocity, sigma)`.
368+
# It must be broadcastable to `[B, T, D]`, so we use `[B, T, 1]`.
369+
timesteps=torch.full(
370+
(batch_size, v_seq_len, 1), 0.5, device=device, dtype=torch.float32
371+
),
372+
positions=_positions(bounds_dims=3, seq_len=v_seq_len),
373+
context=torch.randn(
374+
batch_size, ctx_len, v_caption_in, device=device, dtype=v_ctx_dtype
375+
),
376+
context_mask=None,
377+
)
378+
379+
audio = None
380+
if has_audio:
381+
a_in = int(velocity_model.audio_patchify_proj.in_features)
382+
a_caption_in = int(velocity_model.audio_caption_projection.linear_1.in_features)
383+
a_latent_dtype = _param_dtype(velocity_model.audio_patchify_proj, default_dtype)
384+
a_ctx_dtype = _param_dtype(
385+
velocity_model.audio_caption_projection.linear_1, default_dtype
386+
)
387+
audio = Modality(
388+
enabled=True,
389+
latent=torch.randn(
390+
batch_size, a_seq_len, a_in, device=device, dtype=a_latent_dtype
391+
),
392+
timesteps=torch.full(
393+
(batch_size, a_seq_len, 1), 0.5, device=device, dtype=torch.float32
394+
),
395+
positions=_positions(bounds_dims=1, seq_len=a_seq_len),
396+
context=torch.randn(
397+
batch_size, ctx_len, a_caption_in, device=device, dtype=a_ctx_dtype
398+
),
399+
context_mask=None,
400+
)
401+
402+
perturbations = BatchedPerturbationConfig.empty(batch_size)
403+
model(video, audio, perturbations)
404+
405+
return _ltx2_dummy_forward
406+
407+
# Default: diffusers-style `model(**kwargs)`
408+
def _diffusers_dummy_forward() -> None:
409+
device = next(model.parameters()).device
410+
dtype = next(model.parameters()).dtype
411+
dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype)
412+
if dummy_inputs is None:
413+
raise ValueError(
414+
f"Unknown model type '{type(model).__name__}', cannot generate dummy inputs."
415+
)
416+
model(**dummy_inputs)
417+
418+
return _diffusers_dummy_forward
419+
420+
291421
def is_qkv_projection(module_name: str) -> bool:
292422
"""Check if a module name corresponds to a QKV projection layer.
293423
@@ -377,25 +507,41 @@ def get_qkv_group_key(module_name: str) -> str:
377507
return f"{parent_path}.{qkv_type}"
378508

379509

380-
def get_diffusers_components(
381-
model: DiffusionPipeline | nn.Module,
510+
def get_diffusion_components(
511+
model: Any,
382512
components: list[str] | None = None,
383513
) -> dict[str, Any]:
384-
"""Get all exportable components from a diffusers pipeline.
514+
"""Get all exportable components from a diffusion(-like) pipeline.
385515
386-
This function extracts all components from a DiffusionPipeline including
387-
nn.Module models, tokenizers, schedulers, feature extractors, etc.
516+
Supports:
517+
- diffusers `DiffusionPipeline`: returns `pipeline.components`
518+
- diffusers component `nn.Module` (e.g., UNet / transformer)
519+
- LTX-2 pipeline (duck-typed): returns stage-1 transformer only as `stage_1_transformer`
388520
389521
Args:
390-
model: The diffusers pipeline.
522+
model: The pipeline or component.
391523
components: Optional list of component names to filter. If None, all
392524
components are returned.
393525
394526
Returns:
395527
Dictionary mapping component names to their instances (can be nn.Module,
396528
tokenizers, schedulers, etc.).
397529
"""
398-
if isinstance(model, DiffusionPipeline):
530+
# LTX-2 pipeline: duck-typed stage-1 transformer export
531+
stage_1 = getattr(model, "stage_1_model_ledger", None)
532+
transformer_fn = getattr(stage_1, "transformer", None)
533+
if stage_1 is not None and callable(transformer_fn):
534+
all_components: dict[str, Any] = {"stage_1_transformer": stage_1.transformer()}
535+
if components is not None:
536+
filtered = {name: comp for name, comp in all_components.items() if name in components}
537+
missing = set(components) - set(filtered.keys())
538+
if missing:
539+
warnings.warn(f"Requested components not found in pipeline: {missing}")
540+
return filtered
541+
return all_components
542+
543+
# diffusers pipeline
544+
if _HAS_DIFFUSERS and DiffusionPipeline is not None and isinstance(model, DiffusionPipeline):
399545
# Get all components from the pipeline
400546
all_components = {name: comp for name, comp in model.components.items() if comp is not None}
401547

@@ -427,6 +573,10 @@ def get_diffusers_components(
427573
raise TypeError(f"Expected DiffusionPipeline or nn.Module, got {type(model).__name__}")
428574

429575

576+
# Backward-compatible alias
577+
get_diffusers_components = get_diffusion_components
578+
579+
430580
@contextmanager
431581
def hide_quantizers_from_state_dict(model: nn.Module):
432582
"""Context manager that temporarily removes quantizer modules from the model.

0 commit comments

Comments
 (0)