
Commit ca1f968

[OMNIML-3505] LTX-2 Distillation Trainer (#892)
## What does this PR do?

**Type of change:** new example

**Overview:** Adding the LTX-2 distillation trainer.

## Usage

```bash
accelerate launch \
  --config_file configs/accelerate/fsdp.yaml \
  --num_processes 8 \
  distillation_trainer.py --config configs/distillation_example.yaml
```

See the README for more details.

## Testing

Run training with single/multiple nodes.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: NA
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Summary by CodeRabbit

## New Features

* Added distillation training support for LTX-2 models with quantization integration.
* Introduced comprehensive documentation and example configurations for distillation workflows.
* Includes multi-GPU and multi-node training setup with distributed training support and customizable configuration templates.

Signed-off-by: Meng Xin <mxin@nvidia.com>
1 parent 5c4ef8e commit ca1f968

6 files changed: 2177 additions & 0 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
 - Add support for image-text data calibration in PTQ for Nemotron VL models.
 - Add PTQ support for Nemotron Parse.
+- Add distillation support for LTX-2. See `examples/diffusers/distillation/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/diffusers/distillation>`_ for more details.

 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
examples/diffusers/distillation/README.md

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# LTX-2 Distillation Training with ModelOpt

Knowledge distillation for LTX-2 DiT models using NVIDIA ModelOpt. A frozen **teacher** guides a trainable **student** through a combined loss:

```text
L_total = α × L_task + (1-α) × L_distill
```

Currently supported:

- **Quantization-Aware Distillation (QAD)** — student uses ModelOpt fake quantization

Planned:

- **Sparsity-Aware Distillation (SAD)** — student uses ModelOpt sparsity
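The combined loss above can be sketched in a few lines (a minimal illustration, not the trainer's actual implementation; the function name, tensor shapes, and the use of MSE for both terms are assumptions):

```python
import torch
import torch.nn.functional as F

def combined_distillation_loss(student_pred, teacher_pred, target, alpha=0.5):
    """L_total = alpha * L_task + (1 - alpha) * L_distill (both MSE here)."""
    task_loss = F.mse_loss(student_pred, target)           # student vs. ground-truth velocity
    distill_loss = F.mse_loss(student_pred, teacher_pred)  # student vs. frozen teacher output
    return alpha * task_loss + (1.0 - alpha) * distill_loss

# alpha = 1.0 -> pure task loss; alpha = 0.0 -> pure distillation loss
```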
## Installation

```bash
# From the distillation example directory
cd examples/diffusers/distillation

# Install Model-Optimizer (from repo root)
pip install -e ../../..

# Install all dependencies (ltx-trainer, ltx-core, ltx-pipelines, omegaconf)
pip install -r requirements.txt
```
## Quick Start

### 1. Prepare Your Dataset

Use the ltx-trainer preprocessing to extract latents and text embeddings:

```bash
python -m ltx_trainer.preprocess \
  --input_dir /path/to/videos \
  --output_dir /path/to/preprocessed \
  --model_path /path/to/ltx2/checkpoint.safetensors
```
### 2. Configure

Copy and edit the example config:

```bash
cp configs/distillation_example.yaml configs/my_experiment.yaml
```

Key settings to update:

```yaml
model:
  model_path: "/path/to/ltx2/checkpoint.safetensors"
  text_encoder_path: "/path/to/gemma/model"

data:
  preprocessed_data_root: "/path/to/preprocessed/data"

distillation:
  distillation_alpha: 0.5  # 1.0 = pure task loss, 0.0 = pure distillation
  quant_cfg: "FP8_DEFAULT_CFG"  # or INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, null

# IMPORTANT: disable ltx-trainer's built-in quantization
acceleration:
  quantization: null
```
### 3. Run Training

#### Single GPU

```bash
python distillation_trainer.py --config configs/my_experiment.yaml
```

#### Multi-GPU (Single Node) with Accelerate

```bash
accelerate launch \
  --config_file configs/accelerate/fsdp.yaml \
  --num_processes 8 \
  distillation_trainer.py --config configs/my_experiment.yaml
```
#### Multi-Node Training with Accelerate

To launch on multiple nodes, set the following environment variables on each node:

- `NUM_NODES`: total number of nodes
- `GPUS_PER_NODE`: number of GPUs per node
- `NODE_RANK`: unique rank/index of this node (0-based)
- `MASTER_ADDR`: IP address of the master node (rank 0)
- `MASTER_PORT`: communication port (e.g., 29500)

Then run the following on every node:

```bash
accelerate launch \
  --config_file configs/accelerate/fsdp.yaml \
  --num_machines $NUM_NODES \
  --num_processes $((NUM_NODES * GPUS_PER_NODE)) \
  --machine_rank $NODE_RANK \
  --main_process_ip $MASTER_ADDR \
  --main_process_port $MASTER_PORT \
  distillation_trainer.py --config configs/my_experiment.yaml
```
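Under Slurm, these variables can usually be derived from Slurm's own environment rather than set by hand. A sketch (the Slurm variable names on the right are standard exports, but verify them on your cluster; the `localhost` fallback is only for local testing):

```shell
#!/usr/bin/env bash
# Derive the accelerate-launch variables from Slurm's environment (sketch).
export NUM_NODES=${SLURM_NNODES:-1}
export GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-8}
export NODE_RANK=${SLURM_NODEID:-0}
# First hostname in the job's node list is rank 0; fall back to localhost
# when not running under Slurm.
export MASTER_ADDR=${MASTER_ADDR:-$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -n 1)}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export MASTER_PORT=${MASTER_PORT:-29500}

echo "world size: $((NUM_NODES * GPUS_PER_NODE)), node rank: $NODE_RANK"
```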
**Config overrides** can be passed via the CLI using dotted notation:

```bash
accelerate launch ... distillation_trainer.py \
  --config configs/my_experiment.yaml \
  ++distillation.distillation_alpha=0.6 \
  ++distillation.quant_cfg=INT8_DEFAULT_CFG \
  ++optimization.learning_rate=1e-5
```
## Configuration Reference

### Calibration

Before training begins, calibration runs full denoising inference to collect activation statistics for accurate quantizer scales. The result is cached as a step-0 checkpoint and reused on subsequent runs.

| Parameter | Default | Description |
|-----------|---------|-------------|
| `calibration_prompts_file` | null | Text file with one prompt per line. If null, the HuggingFace dataset `Gustavosta/Stable-Diffusion-Prompts` is used. |
| `calibration_size` | 128 | Number of prompts (each runs a full denoising loop) |
| `calibration_n_steps` | 30 | Denoising steps per prompt |
| `calibration_guidance_scale` | 4.0 | CFG scale (should match inference-time settings) |
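The prompt-source fallback described above can be sketched as follows (illustrative only; the function name is hypothetical, and the `"Prompt"` column name for the HuggingFace dataset is an assumption to verify):

```python
from pathlib import Path

def load_calibration_prompts(prompts_file=None, calibration_size=128):
    """Load calibration prompts from a file, or fall back to the HF dataset."""
    if prompts_file is not None:
        # One prompt per line; skip blank lines.
        prompts = [line.strip()
                   for line in Path(prompts_file).read_text().splitlines()
                   if line.strip()]
    else:
        from datasets import load_dataset  # assumed fallback path
        ds = load_dataset("Gustavosta/Stable-Diffusion-Prompts", split="train")
        prompts = [row["Prompt"] for row in ds.select(range(calibration_size))]
    return prompts[:calibration_size]
```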
### Checkpoint Resume

| Parameter | Default | Description |
|-----------|---------|-------------|
| `resume_from_checkpoint` | null | `"latest"` to auto-detect, or an explicit path |
| `must_save_by` | null | Minutes after which to save and exit (for Slurm time limits) |
| `restore_quantized_checkpoint` | null | Restore a pre-quantized model (skips calibration) |
| `save_quantized_checkpoint` | null | Path to save the final quantized model |
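The `must_save_by` behavior amounts to a deadline check inside the training loop. An illustrative sketch, not the trainer's actual code (class and method names are hypothetical):

```python
import time

class SaveDeadline:
    """Save-and-exit timer for Slurm time limits (illustrative)."""

    def __init__(self, must_save_by_minutes=None):
        # Timer starts at construction, i.e. when train() is called.
        self.deadline = (time.monotonic() + must_save_by_minutes * 60
                         if must_save_by_minutes is not None else None)

    def expired(self):
        return self.deadline is not None and time.monotonic() >= self.deadline

# In the training loop (pseudocode):
# if deadline.expired():
#     save_checkpoint()
#     sys.exit(0)   # exit gracefully before Slurm kills the job
```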
### Custom Quantization Configs

To define custom quantization configs, add entries to `CUSTOM_QUANT_CONFIGS` in `distillation_trainer.py`:

```python
CUSTOM_QUANT_CONFIGS["MY_FP8_CFG"] = {
    "quant_cfg": mtq.FP8_DEFAULT_CFG["quant_cfg"],
    "algorithm": "max",
}
```

Then reference it in your YAML: `quant_cfg: MY_FP8_CFG`.
configs/accelerate/fsdp.yaml

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# FSDP Configuration
#
# FULL_SHARD across all GPUs for maximum memory efficiency.
# For multi-node training with `accelerate launch`.
#
# Usage:
#   accelerate launch \
#     --config_file configs/accelerate/fsdp.yaml \
#     --num_processes 16 \
#     --num_machines 2 \
#     --machine_rank $MACHINE_RANK \
#     --main_process_ip $MASTER_IP \
#     --main_process_port 29500 \
#     distillation_trainer.py --config configs/distillation_example.yaml

distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false

fsdp_config:
  # FULL_SHARD: shard optimizer states, gradients, and parameters across ALL GPUs.
  # This provides maximum memory efficiency for large models like LTX-2 19B;
  # parameters are fully sharded across all nodes (not replicated).
  fsdp_sharding_strategy: FULL_SHARD

  # Enable activation checkpointing to reduce memory during the backward pass.
  # Critical for 19B model training.
  fsdp_activation_checkpointing: true

  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock
  fsdp_use_orig_params: true
  fsdp_version: 1

# Note: num_machines and num_processes are overridden by accelerate launch command-line args.
# These are just defaults for local testing.
num_machines: 1
num_processes: 8
configs/distillation_example.yaml

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# LTX-2 Distillation Training Configuration with ModelOpt

# Model Configuration
model:
  # Path to the LTX-2 checkpoint (used for both teacher and student)
  model_path: "/path/to/ltx2/checkpoint.safetensors"

  # Path to Gemma text encoder (required for LTX-2)
  text_encoder_path: "/path/to/gemma/model"

  # Training mode: "lora" is not supported yet
  training_mode: "full"

# Distillation Configuration
distillation:
  # Path to teacher model (if different from model.model_path).
  # Set to null to use the same checkpoint as the student (loaded without quantization).
  teacher_model_path:

  # Weight for task loss: L_total = α * L_task + (1-α) * L_distill
  # α = 1.0: pure task loss (no distillation)
  # α = 0.0: pure distillation loss
  distillation_alpha: 0.0

  # Type of distillation loss
  # "mse": mean squared error (recommended - transformer outputs are continuous velocity predictions)
  # "cosine": cosine similarity loss (matches direction only, ignores magnitude)
  distillation_loss_type: "mse"

  # Data type for teacher model (bfloat16 recommended for memory efficiency)
  teacher_dtype: "bfloat16"

  # ModelOpt Quantization Settings
  # Name of the mtq config, e.g. FP8_DEFAULT_CFG, INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG.
  # Custom configs defined in CUSTOM_QUANT_CONFIGS (distillation_trainer.py) are also supported.
  quant_cfg:

  # Full-inference calibration settings (matching the PTQ workflow).
  # Each prompt runs a complete denoising loop through the DiT, covering all noise levels.
  # Path to a text file with one prompt per line. If null, uses the default
  # HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' (same as PTQ).
  calibration_prompts_file:
  # Total number of calibration prompts (set to 0 to skip calibration)
  calibration_size: 128
  # Number of denoising steps per prompt (matches PTQ --n-steps)
  calibration_n_steps: 30
  # CFG guidance scale during calibration (4.0 = PTQ default, calls the transformer
  # twice per step for positive + negative prompt; 1.0 = no CFG, saves memory)
  calibration_guidance_scale: 4.0

  # Path to restore a previously quantized model (from mto.save)
  restore_quantized_checkpoint:

  # Path to save the final quantized model checkpoint
  save_quantized_checkpoint:

  # Resume from a full training state checkpoint (saves model + optimizer + RNG + step).
  # Set to "latest" to auto-find the most recent checkpoint in output_dir/checkpoints/,
  # or set an explicit path like "/path/to/checkpoints/step_001000".
  resume_from_checkpoint: latest

  # Time-limit-aware saving for Slurm jobs.
  # Minutes after which training must save a checkpoint and exit gracefully.
  # Set slightly below your Slurm --time limit (e.g. time=30min -> must_save_by: 25).
  # The timer starts when train() is called (after model loading/calibration).
  must_save_by:

  # Debug/Test: use mock data instead of real preprocessed data.
  # Useful for testing the training pipeline without preparing a dataset.
  use_mock_data: false
  mock_data_samples: 100

# Training Strategy
training_strategy:
  name: "text_to_video"
  first_frame_conditioning_p: 0.1
  with_audio: false

# Optimization Configuration
optimization:
  learning_rate: 2.0e-6
  steps: 10000
  batch_size: 1
  gradient_accumulation_steps: 4
  max_grad_norm: 1.0
  optimizer_type: "adamw"  # Use "adamw8bit" for memory efficiency
  scheduler_type: "cosine"
  enable_gradient_checkpointing: true  # Essential for memory savings

# Acceleration Configuration
acceleration:
  mixed_precision_mode: "bf16"

  # NOTE: keep this null - ModelOpt quantization is used instead of ltx-trainer's quanto
  quantization:

  # 8-bit text encoder for memory savings
  load_text_encoder_in_8bit: false

# Data Configuration
data:
  # Path to preprocessed training data (created by process_dataset.py)
  preprocessed_data_root: "/path/to/preprocessed/data"
  num_dataloader_workers: 2

# Validation Configuration
validation:
  prompts:
    - "A beautiful sunset over the ocean with gentle waves"
    - "A cat playing with a ball of yarn in a cozy living room"
  negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
  video_dims: [512, 320, 33]  # [width, height, frames]
  frame_rate: 25.0
  inference_steps: 30
  interval: 500  # Validate every 500 steps
  guidance_scale: 4.0
  seed: 42

# Checkpointing Configuration
checkpoints:
  interval: 1000  # Save a checkpoint every 1000 steps
  keep_last_n: 3  # Keep only the last 3 checkpoints
  precision: "bfloat16"

# Weights & Biases Logging
wandb:
  enabled: true
  project: "ltx2-distillation"
  entity:  # Your W&B username or team
  tags:
    - "distillation"
    - "modelopt"
  log_validation_videos: true

# Flow Matching Configuration
flow_matching:
  timestep_sampling_mode: "shifted_logit_normal"
  timestep_sampling_params: {}

# General Settings
seed: 42
output_dir: "./outputs/distillation_experiment"
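For reference, `shifted_logit_normal` timestep sampling as configured above can be sketched as follows. This illustrates the general technique used by flow-matching trainers (sample a logit-normal variate, then apply a timestep shift); the default parameter values here are assumptions, not LTX-2's actual settings:

```python
import math
import random

def sample_timestep(mean=0.0, std=1.0, shift=1.0):
    """Logit-normal timestep sample with an optional timestep shift (sketch)."""
    u = random.gauss(mean, std)
    t = 1.0 / (1.0 + math.exp(-u))  # sigmoid maps the Gaussian into (0, 1)
    # Timestep shift; identity when shift == 1, biases toward higher t when shift > 1.
    return shift * t / (1.0 + (shift - 1.0) * t)
```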
