|
16 | 16 | """Code that export quantized Hugging Face models for deployment.""" |
17 | 17 |
|
18 | 18 | import warnings |
| 19 | +from collections.abc import Callable |
19 | 20 | from contextlib import contextmanager |
20 | 21 | from importlib import import_module |
21 | 22 | from typing import Any |
22 | 23 |
|
23 | 24 | import torch |
24 | 25 | import torch.nn as nn |
25 | | -from diffusers import DiffusionPipeline |
26 | 26 |
|
27 | 27 | from .layer_utils import is_quantlinear |
28 | 28 |
|
| 29 | +DiffusionPipeline: type[Any] | None |
| 30 | +try: # diffusers is optional for LTX-2 export paths |
| 31 | + from diffusers import DiffusionPipeline as _DiffusionPipeline |
| 32 | + |
| 33 | + DiffusionPipeline = _DiffusionPipeline |
| 34 | + _HAS_DIFFUSERS = True |
| 35 | +except Exception: # pragma: no cover |
| 36 | + DiffusionPipeline = None |
| 37 | + _HAS_DIFFUSERS = False |
| 38 | + |
29 | 39 |
|
30 | 40 | def generate_diffusion_dummy_inputs( |
31 | 41 | model: nn.Module, device: torch.device, dtype: torch.dtype |
@@ -288,6 +298,126 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: |
288 | 298 | return None |
289 | 299 |
|
290 | 300 |
|
| 301 | +def generate_diffusion_dummy_forward_fn(model: nn.Module) -> Callable[[], None]: |
| 302 | + """Create a dummy forward function for diffusion(-like) models. |
| 303 | +
|
| 304 | + - For diffusers components, this uses `generate_diffusion_dummy_inputs()` and calls `model(**kwargs)`. |
| 305 | + - For LTX-2 stage-1 transformer (X0Model), the forward signature is |
| 306 | + `model(video: Modality|None, audio: Modality|None, perturbations: BatchedPerturbationConfig)`, |
| 307 | + so we build tiny `ltx_core` dataclasses and call the model directly. |
| 308 | + """ |
| 309 | + # Duck-typed LTX-2 stage-1 transformer wrapper |
| 310 | + velocity_model = getattr(model, "velocity_model", None) |
| 311 | + if velocity_model is not None: |
| 312 | + |
| 313 | + def _ltx2_dummy_forward() -> None: |
| 314 | + try: |
| 315 | + from ltx_core.guidance.perturbations import BatchedPerturbationConfig |
| 316 | + from ltx_core.model.transformer.modality import Modality |
| 317 | + except Exception as e: # pragma: no cover |
| 318 | + raise RuntimeError( |
| 319 | + "LTX-2 export requires `ltx_core` to be installed (Modality, BatchedPerturbationConfig)." |
| 320 | + ) from e |
| 321 | + |
| 322 | + # Small shapes for speed/memory |
| 323 | + batch_size = 1 |
| 324 | + v_seq_len = 8 |
| 325 | + a_seq_len = 8 |
| 326 | + ctx_len = 4 |
| 327 | + |
| 328 | + device = next(model.parameters()).device |
| 329 | + default_dtype = next(model.parameters()).dtype |
| 330 | + |
| 331 | + def _param_dtype(module: Any, fallback: torch.dtype) -> torch.dtype: |
| 332 | + w = getattr(getattr(module, "weight", None), "dtype", None) |
| 333 | + return w if isinstance(w, torch.dtype) else fallback |
| 334 | + |
| 335 | + def _positions(bounds_dims: int, seq_len: int) -> torch.Tensor: |
| 336 | + # [B, dims, seq_len, 2] bounds (start/end) |
| 337 | + pos = torch.zeros( |
| 338 | + (batch_size, bounds_dims, seq_len, 2), device=device, dtype=torch.float32 |
| 339 | + ) |
| 340 | + pos[..., 1] = 1.0 |
| 341 | + return pos |
| 342 | + |
| 343 | + has_video = hasattr(velocity_model, "patchify_proj") and hasattr( |
| 344 | + velocity_model, "caption_projection" |
| 345 | + ) |
| 346 | + has_audio = hasattr(velocity_model, "audio_patchify_proj") and hasattr( |
| 347 | + velocity_model, "audio_caption_projection" |
| 348 | + ) |
| 349 | + if not has_video and not has_audio: |
| 350 | + raise ValueError( |
| 351 | + "Unsupported LTX-2 velocity model: missing both video and audio preprocessors." |
| 352 | + ) |
| 353 | + |
| 354 | + video = None |
| 355 | + if has_video: |
| 356 | + v_in = int(velocity_model.patchify_proj.in_features) |
| 357 | + v_caption_in = int(velocity_model.caption_projection.linear_1.in_features) |
| 358 | + v_latent_dtype = _param_dtype(velocity_model.patchify_proj, default_dtype) |
| 359 | + v_ctx_dtype = _param_dtype( |
| 360 | + velocity_model.caption_projection.linear_1, default_dtype |
| 361 | + ) |
| 362 | + video = Modality( |
| 363 | + enabled=True, |
| 364 | + latent=torch.randn( |
| 365 | + batch_size, v_seq_len, v_in, device=device, dtype=v_latent_dtype |
| 366 | + ), |
| 367 | + # LTX `X0Model` uses `timesteps` as the sigma tensor in `to_denoised(sample, velocity, sigma)`. |
| 368 | + # It must be broadcastable to `[B, T, D]`, so we use `[B, T, 1]`. |
| 369 | + timesteps=torch.full( |
| 370 | + (batch_size, v_seq_len, 1), 0.5, device=device, dtype=torch.float32 |
| 371 | + ), |
| 372 | + positions=_positions(bounds_dims=3, seq_len=v_seq_len), |
| 373 | + context=torch.randn( |
| 374 | + batch_size, ctx_len, v_caption_in, device=device, dtype=v_ctx_dtype |
| 375 | + ), |
| 376 | + context_mask=None, |
| 377 | + ) |
| 378 | + |
| 379 | + audio = None |
| 380 | + if has_audio: |
| 381 | + a_in = int(velocity_model.audio_patchify_proj.in_features) |
| 382 | + a_caption_in = int(velocity_model.audio_caption_projection.linear_1.in_features) |
| 383 | + a_latent_dtype = _param_dtype(velocity_model.audio_patchify_proj, default_dtype) |
| 384 | + a_ctx_dtype = _param_dtype( |
| 385 | + velocity_model.audio_caption_projection.linear_1, default_dtype |
| 386 | + ) |
| 387 | + audio = Modality( |
| 388 | + enabled=True, |
| 389 | + latent=torch.randn( |
| 390 | + batch_size, a_seq_len, a_in, device=device, dtype=a_latent_dtype |
| 391 | + ), |
| 392 | + timesteps=torch.full( |
| 393 | + (batch_size, a_seq_len, 1), 0.5, device=device, dtype=torch.float32 |
| 394 | + ), |
| 395 | + positions=_positions(bounds_dims=1, seq_len=a_seq_len), |
| 396 | + context=torch.randn( |
| 397 | + batch_size, ctx_len, a_caption_in, device=device, dtype=a_ctx_dtype |
| 398 | + ), |
| 399 | + context_mask=None, |
| 400 | + ) |
| 401 | + |
| 402 | + perturbations = BatchedPerturbationConfig.empty(batch_size) |
| 403 | + model(video, audio, perturbations) |
| 404 | + |
| 405 | + return _ltx2_dummy_forward |
| 406 | + |
| 407 | + # Default: diffusers-style `model(**kwargs)` |
| 408 | + def _diffusers_dummy_forward() -> None: |
| 409 | + device = next(model.parameters()).device |
| 410 | + dtype = next(model.parameters()).dtype |
| 411 | + dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype) |
| 412 | + if dummy_inputs is None: |
| 413 | + raise ValueError( |
| 414 | + f"Unknown model type '{type(model).__name__}', cannot generate dummy inputs." |
| 415 | + ) |
| 416 | + model(**dummy_inputs) |
| 417 | + |
| 418 | + return _diffusers_dummy_forward |
| 419 | + |
| 420 | + |
291 | 421 | def is_qkv_projection(module_name: str) -> bool: |
292 | 422 | """Check if a module name corresponds to a QKV projection layer. |
293 | 423 |
|
@@ -377,25 +507,41 @@ def get_qkv_group_key(module_name: str) -> str: |
377 | 507 | return f"{parent_path}.{qkv_type}" |
378 | 508 |
|
379 | 509 |
|
380 | | -def get_diffusers_components( |
381 | | - model: DiffusionPipeline | nn.Module, |
| 510 | +def get_diffusion_components( |
| 511 | + model: Any, |
382 | 512 | components: list[str] | None = None, |
383 | 513 | ) -> dict[str, Any]: |
384 | | - """Get all exportable components from a diffusers pipeline. |
| 514 | + """Get all exportable components from a diffusion(-like) pipeline. |
385 | 515 |
|
386 | | - This function extracts all components from a DiffusionPipeline including |
387 | | - nn.Module models, tokenizers, schedulers, feature extractors, etc. |
| 516 | + Supports: |
| 517 | + - diffusers `DiffusionPipeline`: returns `pipeline.components` |
| 518 | + - diffusers component `nn.Module` (e.g., UNet / transformer) |
| 519 | + - LTX-2 pipeline (duck-typed): returns stage-1 transformer only as `stage_1_transformer` |
388 | 520 |
|
389 | 521 | Args: |
390 | | - model: The diffusers pipeline. |
| 522 | + model: The pipeline or component. |
391 | 523 | components: Optional list of component names to filter. If None, all |
392 | 524 | components are returned. |
393 | 525 |
|
394 | 526 | Returns: |
395 | 527 | Dictionary mapping component names to their instances (can be nn.Module, |
396 | 528 | tokenizers, schedulers, etc.). |
397 | 529 | """ |
398 | | - if isinstance(model, DiffusionPipeline): |
| 530 | + # LTX-2 pipeline: duck-typed stage-1 transformer export |
| 531 | + stage_1 = getattr(model, "stage_1_model_ledger", None) |
| 532 | + transformer_fn = getattr(stage_1, "transformer", None) |
| 533 | + if stage_1 is not None and callable(transformer_fn): |
| 534 | + all_components: dict[str, Any] = {"stage_1_transformer": stage_1.transformer()} |
| 535 | + if components is not None: |
| 536 | + filtered = {name: comp for name, comp in all_components.items() if name in components} |
| 537 | + missing = set(components) - set(filtered.keys()) |
| 538 | + if missing: |
| 539 | + warnings.warn(f"Requested components not found in pipeline: {missing}") |
| 540 | + return filtered |
| 541 | + return all_components |
| 542 | + |
| 543 | + # diffusers pipeline |
| 544 | + if _HAS_DIFFUSERS and DiffusionPipeline is not None and isinstance(model, DiffusionPipeline): |
399 | 545 | # Get all components from the pipeline |
400 | 546 | all_components = {name: comp for name, comp in model.components.items() if comp is not None} |
401 | 547 |
|
@@ -427,6 +573,10 @@ def get_diffusers_components( |
427 | 573 | raise TypeError(f"Expected DiffusionPipeline or nn.Module, got {type(model).__name__}") |
428 | 574 |
|
429 | 575 |
|
| 576 | +# Backward-compatible alias |
| 577 | +get_diffusers_components = get_diffusion_components |
| 578 | + |
| 579 | + |
430 | 580 | @contextmanager |
431 | 581 | def hide_quantizers_from_state_dict(model: nn.Module): |
432 | 582 | """Context manager that temporarily removes quantizer modules from the model. |
|
0 commit comments