Conversation
📝 Walkthrough

Adds a new LTX-2 Quantization-Aware Distillation (QAD) example for Windows/Linux diffusers: README, dependencies, FSDP and LTX config YAMLs, and a large sample script implementing training, ModelOpt PTQ calibration, distillation, checkpointing, and inference-checkpoint creation.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Trainer as LtxvQADTrainer
    participant Quantizer as ModelOpt<br/>Quantizer
    participant Student as Student<br/>Model
    participant Teacher as Teacher<br/>Model
    participant Checkpoint as Checkpoint<br/>Manager
    User->>Trainer: start training (config, data)
    Trainer->>Student: load student model
    Trainer->>Quantizer: run PTQ calibration (calib dataset)
    Quantizer->>Student: compute/calibrate amaxs
    Quantizer-->>Trainer: return amax metadata
    Trainer->>Teacher: load teacher model (if configured)
    Trainer->>Student: wrap for distillation
    loop per training step
        Trainer->>Student: forward(batch)
        Student->>Teacher: request teacher outputs
        Teacher-->>Student: teacher targets
        Student->>Trainer: compute base + KD loss
        Trainer->>Trainer: backward & optimizer step
    end
    Trainer->>Checkpoint: save checkpoint (state, amaxs, filtered keys)
    Checkpoint-->>Trainer: checkpoint saved (safetensors)
```
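The "compute base + KD loss" step in the diagram above is a weighted blend controlled by `kd_loss_weight` (the script's CLI help describes 0 as pure hard loss and 1 as pure KD loss). A framework-free sketch of that blending — the convex combination and the use of squared error are illustrative assumptions, not the script's exact loss:

```python
def combined_qad_loss(student_out, teacher_out, target, kd_loss_weight=0.5):
    """Blend the hard (ground-truth) loss with the distillation loss.

    kd_loss_weight=0 -> pure hard loss; kd_loss_weight=1 -> pure KD loss.
    Plain floats stand in for tensors; squared error is an illustrative choice.
    """
    hard_loss = (student_out - target) ** 2      # supervised term vs. ground truth
    kd_loss = (student_out - teacher_out) ** 2   # match the teacher's output
    return (1.0 - kd_loss_weight) * hard_loss + kd_loss_weight * kd_loss


# kd_loss_weight=0 ignores the teacher entirely
print(combined_qad_loss(1.0, 2.0, 3.0, kd_loss_weight=0.0))  # 4.0
# kd_loss_weight=1 ignores the ground truth
print(combined_qad_loss(1.0, 2.0, 3.0, kd_loss_weight=1.0))  # 1.0
```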
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (2 passed)
To be clear, will this example work for single RTX GPU use-cases (e.g. RTX 5090)? Or, are there any changes / limitations on this part?
b39e963 to 28d6dc1
Actionable comments posted: 5
🧹 Nitpick comments (3)
examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py (2)
383-392: Consider making calibration seed configurable.

The calibration seed is hardcoded to 42 (line 385). While this ensures reproducibility, consider making it configurable through the trainer config for users who want different calibration orderings.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py` around lines 383 - 392, The hardcoded calibration seed (torch.manual_seed(42)) should be made configurable: add a calibration seed option in the trainer/config (e.g., self._config.optimization.calibration_seed), read that value (fall back to 42 if not set), and call torch.manual_seed(calibration_seed) before creating the calib_loader DataLoader (which uses self._config.optimization.batch_size); this preserves reproducibility while allowing users to change calibration ordering.
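A minimal, torch-free sketch of the fallback pattern this comment asks for; `OptimizationConfig` and the attribute name `calibration_seed` are illustrative stand-ins for whatever the real trainer config exposes:

```python
class OptimizationConfig:
    """Illustrative config object; a real trainer config would be richer."""

    def __init__(self, calibration_seed=None):
        self.calibration_seed = calibration_seed


def resolve_calibration_seed(config, default: int = 42) -> int:
    """Use the configured seed when present, falling back to the old default of 42."""
    seed = getattr(config, "calibration_seed", None)
    return seed if seed is not None else default


# Unset config keeps the historical behavior; explicit values win.
print(resolve_calibration_seed(OptimizationConfig()))   # 42
print(resolve_calibration_seed(OptimizationConfig(7)))  # 7
```

The resolved value would then be passed to `torch.manual_seed(...)` before building the calibration DataLoader.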
216-228: Nested dict handling is limited to one level.

`move_batch_to_device` only handles one level of nested dictionaries. If batch data has deeper nesting, inner tensors won't be moved to device. This may be intentional given the expected data structure, but worth noting.

♻️ Optional: recursive implementation

```diff
 def move_batch_to_device(batch: dict, device: torch.device) -> dict:
     """Recursively move batch tensors to device."""
+    def _move(obj):
+        if isinstance(obj, dict):
+            return {k: _move(v) for k, v in obj.items()}
+        elif isinstance(obj, torch.Tensor):
+            return obj.to(device)
+        return obj
+
+    return _move(batch)
-    result = {}
-    for k, v in batch.items():
-        if isinstance(v, dict):
-            result[k] = {
-                ik: iv.to(device) if isinstance(iv, torch.Tensor) else iv for ik, iv in v.items()
-            }
-        elif isinstance(v, torch.Tensor):
-            result[k] = v.to(device)
-        else:
-            result[k] = v
-    return result
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py` around lines 216-228: move_batch_to_device only handles one-level nested dicts, so tensors in deeper nested structures won't be moved. Update move_batch_to_device to recursively traverse arbitrary nested containers (dicts, lists, tuples) and call .to(device) on any torch.Tensor encountered, preserving container types and non-tensor values, e.g., by replacing the current one-level dict branch (isinstance(v, dict)) with a recursive helper, or by making move_batch_to_device itself recurse for dict/list/tuple cases and return the same structure with tensors moved to the provided device.

examples/windows/torch_onnx/diffusers/qad_example/README.md (1)

99-111: Minor style: consider simplifying phrasing.

The static analysis tool suggests "All of the following" could be simplified to "All the following" for conciseness. This is optional.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/windows/torch_onnx/diffusers/qad_example/README.md` around lines 99 - 111, Replace the verbose phrasing "All of the following" in the README surrounding the QAD/optimization table with the shorter "All the following"; search the README.md near the QAD section (keys like `qad`, `calib_size`, `kd_loss_weight`, `exclude_blocks`, `skip_inference_ckpt`) and update the sentence so it reads "All the following" while preserving surrounding punctuation and formatting.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between b39e963d65dfed2a22a963db7d1e789abf1b167d and 28d6dc1.
📒 Files selected for processing (5)
- examples/windows/torch_onnx/diffusers/qad_example/README.md
- examples/windows/torch_onnx/diffusers/qad_example/fsdp_custom.yaml
- examples/windows/torch_onnx/diffusers/qad_example/ltx2_qad.yaml
- examples/windows/torch_onnx/diffusers/qad_example/requirements.txt
- examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/windows/torch_onnx/diffusers/qad_example/requirements.txt
- examples/windows/torch_onnx/diffusers/qad_example/fsdp_custom.yaml
```yaml
model:
  model_path: "/lustre/fsw/portfolios/adlr/projects/adlr_psx_numerics/users/ynankani/ComfyUI/models/checkpoints/ltx-av-step-1933500-split-new-vae.safetensors"
  training_mode: "full"
  load_checkpoint:
  text_encoder_path: "/lustre/fsw/portfolios/adlr/users/dhutchins/models/gemma"
```
Replace internal filesystem paths with placeholders.

The config contains hardcoded paths to an internal Lustre filesystem that won't work for external users:

- `model_path`: /lustre/fsw/.../ltx-av-step-1933500-split-new-vae.safetensors
- `text_encoder_path`: /lustre/fsw/.../gemma

Replace with placeholder paths that users can easily identify and update.

📝 Suggested placeholder paths

```diff
 model:
-  model_path: "/lustre/fsw/portfolios/adlr/projects/adlr_psx_numerics/users/ynankani/ComfyUI/models/checkpoints/ltx-av-step-1933500-split-new-vae.safetensors"
+  model_path: "/path/to/ltx2/checkpoint.safetensors" # TODO: Set your LTX-2 checkpoint path
   training_mode: "full"
   load_checkpoint:
-  text_encoder_path: "/lustre/fsw/portfolios/adlr/users/dhutchins/models/gemma"
+  text_encoder_path: "/path/to/gemma" # TODO: Set your Gemma text encoder path
```

🤖 Prompt for AI Agents
+ text_encoder_path: "/path/to/gemma" # TODO: Set your Gemma text encoder path🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/windows/torch_onnx/diffusers/qad_example/ltx2_qad.yaml` around lines
2 - 6, The YAML contains hardcoded internal Lustre paths for model.model_path
and model.text_encoder_path; replace those absolute internal filesystem values
with clear placeholder paths (e.g., "<PATH_TO_MODEL_CHECKPOINT>.safetensors" and
"<PATH_TO_TEXT_ENCODER_DIR>") so external users can update them; update the
values for the keys model.model_path and model.text_encoder_path in
ltx2_qad.yaml and keep training_mode and load_checkpoint unchanged.
```yaml
data:
  preprocessed_data_root: "/lustre/fsw/portfolios/adlr/users/scavallari/ltx-qad/qad-dataset"
  num_dataloader_workers: 2
```
Replace internal data path with placeholder.

The `preprocessed_data_root` path is also an internal Lustre path that external users cannot access.

📝 Suggested fix

```diff
 data:
-  preprocessed_data_root: "/lustre/fsw/portfolios/adlr/users/scavallari/ltx-qad/qad-dataset"
+  preprocessed_data_root: "/path/to/preprocessed" # TODO: Set your preprocessed dataset path
   num_dataloader_workers: 2
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/windows/torch_onnx/diffusers/qad_example/ltx2_qad.yaml` around lines
28 - 30, Replace the hard-coded internal Lustre path assigned to the YAML key
preprocessed_data_root with a neutral placeholder or environment variable
reference (for example "/path/to/preprocessed_data" or
"${PREPROCESSED_DATA_ROOT}") so external users can supply their own dataset
location; update the value of preprocessed_data_root in ltx2_qad.yaml
accordingly.
```yaml
qad:
  calib_size: 10
  kd_loss_weight: 1.0
  exclude_blocks: [0, 1, 46, 47]
  skip_inference_ckpt: false
```
Documentation mismatch: YAML defaults differ from README.

The YAML default values don't match the README documentation (lines 101-102):

| Parameter | YAML default | README states |
|---|---|---|
| `calib_size` | 10 | 512 |
| `kd_loss_weight` | 1.0 | 0.5 |

Either update the YAML to match the README or update the README to reflect the actual defaults.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/windows/torch_onnx/diffusers/qad_example/ltx2_qad.yaml` around lines
70 - 74, The YAML default settings in ltx2_qad.yaml (calib_size: 10,
kd_loss_weight: 1.0) conflict with the README documented defaults (calib_size:
512, kd_loss_weight: 0.5); pick one source of truth and make them consistent:
either update the YAML parameters calib_size to 512 and kd_loss_weight to 0.5 to
match the README, or update the README entries for calib_size and kd_loss_weight
to reflect the YAML values (10 and 1.0), then run a quick grep through
docs/usage to ensure no other references remain inconsistent.
```python
def load_state_dict_any_format(path: str, label: str = "") -> tuple[dict, dict | None]:
    """Load state dict from either torch pickle or safetensors."""
    fmt = detect_format(path)
    logger.info(f"[{label}] Detected format: {fmt} for {path}")

    if fmt == "torch":
        raw = torch.load(path, map_location="cpu", weights_only=False)
        if isinstance(raw, dict) and "state_dict" in raw:
            return raw["state_dict"], None
        return raw, None
    else:
        try:
            from safetensors.torch import load_file, safe_open

            with safe_open(path, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}
            return load_file(path, device="cpu"), metadata
        except Exception as e:
            logger.warning(f"safe_open failed ({e}), trying manual parse...")
            return _load_safetensors_manual(path)
```
CRITICAL: torch.load with weights_only=False requires justification comment.

Per the security coding guidelines in SECURITY.md, using `torch.load(..., weights_only=False)` without an inline comment justifying why it is safe is a critical security issue. Pickle deserialization can execute arbitrary code, and this function loads user-provided checkpoint paths.

Either:

- Add an inline comment explaining why this is safe (e.g., files are internally-generated and not user-supplied), OR
- Use `weights_only=True` and handle the restricted loading, OR
- Prefer safetensors format, which doesn't have this vulnerability

🔒 Option 1: Add justification comment

```diff
     if fmt == "torch":
-        raw = torch.load(path, map_location="cpu", weights_only=False)
+        # Safe: Checkpoint paths are provided by the user who controls their own training pipeline.
+        # Users are responsible for ensuring checkpoint provenance.
+        raw = torch.load(path, map_location="cpu", weights_only=False)
         if isinstance(raw, dict) and "state_dict" in raw:
             return raw["state_dict"], None
         return raw, None
```

🔒 Option 2: Add warning about untrusted files

```diff
 def load_state_dict_any_format(path: str, label: str = "") -> tuple[dict, dict | None]:
-    """Load state dict from either torch pickle or safetensors."""
+    """Load state dict from either torch pickle or safetensors.
+
+    WARNING: Torch pickle format uses unsafe deserialization. Only load
+    checkpoints from trusted sources. Prefer safetensors format when possible.
+    """
     fmt = detect_format(path)
     logger.info(f"[{label}] Detected format: {fmt} for {path}")
     if fmt == "torch":
+        logger.warning(
+            f"Loading torch pickle checkpoint from {path}. "
+            "Only load checkpoints from trusted sources."
+        )
         raw = torch.load(path, map_location="cpu", weights_only=False)
```

As per coding guidelines: "torch.load(..., weights_only=False) with no inline comment justifying why it is safe" must be flagged as a CRITICAL security issue.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`
around lines 134 - 154, The call to torch.load in load_state_dict_any_format
uses weights_only=False without justification; update this by either (A) adding
an inline comment immediately next to the torch.load call explaining why
deserializing pickles here is safe (e.g., sources are internal/trusted and not
user-supplied), or (B) change the call to torch.load(..., weights_only=True) and
adjust subsequent logic to handle the restricted load, or (C) prefer/require
safetensors (detect_format fallback) for untrusted inputs and log/warn or raise
if a torch pickle is provided from an untrusted path; reference the torch.load
invocation and the weights_only parameter in your change.
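Conceptually, `weights_only=True` is a restricted unpickler that refuses to resolve arbitrary globals. The same idea can be sketched with the standard library alone, independent of torch; the allow-list contents and helper names here are illustrative, not what torch actually whitelists:

```python
import io
import pickle

# Illustrative allow-list: only globals named here may be reconstructed.
_ALLOWED = {("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    """Refuse to resolve any global not on the allow-list."""

    def find_class(self, module, name):
        if (module, name) in _ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global: {module}.{name}")


def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()


# Plain containers of primitives deserialize fine...
print(restricted_loads(pickle.dumps({"state_dict": {"w": [1.0, 2.0]}})))
# ...but payloads that smuggle in arbitrary callables are rejected.
try:
    restricted_loads(pickle.dumps(pickle.dumps))  # pickles a function object by reference
except pickle.UnpicklingError as e:
    print("rejected:", e)
```

This is why safetensors (a pure tensor container with no code execution path) is the preferred format for untrusted checkpoints.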
```python
import yaml

with open(args.config) as f:
    config_dict = yaml.safe_load(f)

# Extract QAD-specific config (not part of LtxTrainerConfig)
qad_config = config_dict.pop("qad", {})

config = LtxTrainerConfig(**config_dict)

# Resolve QAD params: CLI args override YAML values, YAML overrides defaults
calib_size = args.calib_size if args.calib_size != 512 else qad_config.get("calib_size", 512)
kd_loss_weight = (
    args.kd_loss_weight if args.kd_loss_weight != 0.5 else qad_config.get("kd_loss_weight", 0.5)
)
exclude_blocks = (
    args.exclude_blocks
    if args.exclude_blocks != [0, 1, 46, 47]
    else qad_config.get("exclude_blocks", [0, 1, 46, 47])
)
```
CLI override logic has edge case: user explicitly passing default value.

The override logic checks if CLI args differ from hardcoded defaults to determine whether to use CLI or YAML values:

```python
calib_size = args.calib_size if args.calib_size != 512 else qad_config.get("calib_size", 512)
```

If a user explicitly passes `--calib-size 512` to override a YAML value of 256, the YAML value (256) would be used instead. Consider using argparse's default tracking or explicit "was this arg provided" checking.

♻️ Suggested fix using sentinel values

```diff
 train_parser.add_argument(
     "--calib-size",
     type=int,
-    default=512,
+    default=None,
     help="Number of calibration batches for PTQ",
 )
 train_parser.add_argument(
     "--kd-loss-weight",
     type=float,
-    default=0.5,
+    default=None,
     help="KD loss weight (0=pure hard loss, 1=pure KD loss)",
 )
 train_parser.add_argument(
     "--exclude-blocks",
     type=int,
     nargs="*",
-    default=[0, 1, 46, 47],
+    default=None,
     help="Transformer block indices to exclude from quantization",
 )
```

Then in main:

```diff
-calib_size = args.calib_size if args.calib_size != 512 else qad_config.get("calib_size", 512)
+calib_size = args.calib_size if args.calib_size is not None else qad_config.get("calib_size", 512)
 kd_loss_weight = (
-    args.kd_loss_weight if args.kd_loss_weight != 0.5 else qad_config.get("kd_loss_weight", 0.5)
+    args.kd_loss_weight if args.kd_loss_weight is not None else qad_config.get("kd_loss_weight", 0.5)
 )
 exclude_blocks = (
     args.exclude_blocks
-    if args.exclude_blocks != [0, 1, 46, 47]
+    if args.exclude_blocks is not None
     else qad_config.get("exclude_blocks", [0, 1, 46, 47])
 )
```
+ )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`
around lines 942 - 961, The CLI-vs-YAML override wrongly treats explicit CLI
values equal to hardcoded defaults as "not provided"; change the argparse
defaults to sentinel (e.g., None) and update the resolution code to check for
None instead of comparing to hardcoded defaults: replace the current assignments
for calib_size, kd_loss_weight, and exclude_blocks to pick args.<name> if
args.<name> is not None else qad_config.get("<name>", <hardcoded_default>) so
explicit CLI values (including previous defaults) correctly override YAML;
reference the symbols args, calib_size, kd_loss_weight, and exclude_blocks when
making the edits.
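The None-sentinel precedence described above can be demonstrated standalone; `build_parser` and `resolve` are illustrative reductions of the script's `train_parser` and merge logic:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # None sentinels: "not provided" is distinguishable from "provided the old default".
    parser.add_argument("--calib-size", type=int, default=None)
    parser.add_argument("--kd-loss-weight", type=float, default=None)
    return parser


def resolve(args: argparse.Namespace, yaml_cfg: dict) -> tuple:
    """CLI wins when explicitly given; otherwise YAML; otherwise hardcoded default."""
    calib_size = args.calib_size if args.calib_size is not None else yaml_cfg.get("calib_size", 512)
    kd = args.kd_loss_weight if args.kd_loss_weight is not None else yaml_cfg.get("kd_loss_weight", 0.5)
    return calib_size, kd


yaml_cfg = {"calib_size": 256}
# An explicit CLI value equal to the old hardcoded default now still wins over YAML.
print(resolve(build_parser().parse_args(["--calib-size", "512"]), yaml_cfg))  # (512, 0.5)
# Nothing on the CLI: YAML wins, then the hardcoded default.
print(resolve(build_parser().parse_args([]), yaml_cfg))  # (256, 0.5)
```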
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@ Coverage Diff @@
##             main     #933   +/-   ##
=======================================
  Coverage   72.13%   72.13%
=======================================
  Files         209      209
  Lines       23631    23631
=======================================
  Hits        17046    17046
  Misses       6585     6585
```

☔ View full report in Codecov by Sentry.
vishalpandya1990 left a comment

Add a note about sample-script for illustrating pipeline, and Linux RTX verification of it.
Signed-off-by: ynankani <ynankani@nvidia.com>
28d6dc1 to bdf2f3e
Actionable comments posted: 3
♻️ Duplicate comments (2)
examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py (2)
139-143: ⚠️ Potential issue | 🔴 Critical

Critical: unsafe pickle deserialization in checkpoint loading path.

`torch.load(..., weights_only=False)` is used on file paths that come from CLI inputs, which keeps pickle object deserialization enabled.

🔒 Suggested fix

```diff
-    raw = torch.load(path, map_location="cpu", weights_only=False)
+    raw = torch.load(path, map_location="cpu", weights_only=True)
```

As per coding guidelines: "Do not use `torch.load(..., weights_only=False)` unless a documented exception is provided."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py` around lines 139-143: the code currently calls torch.load(path, map_location="cpu", weights_only=False), which allows unsafe pickle deserialization. Change the call to torch.load(path, map_location="cpu", weights_only=True) to avoid untrusted pickle deserialization when loading CLI-provided checkpoints (keep the existing handling of the returned object and "state_dict" key), and add a short fallback/error path if weights_only=True fails for an older checkpoint format, so we fail safe rather than re-enabling pickle.

845-861: ⚠️ Potential issue | 🟡 Minor

CLI override precedence still drops explicit default-valued args.

If a user explicitly passes `--calib-size 512` (or equivalent default), YAML still wins due to value-comparison logic.

♻️ Suggested fix

```diff
 train_parser.add_argument(
     "--calib-size",
     type=int,
-    default=512,
+    default=None,
     help="Number of calibration batches for PTQ",
 )
 train_parser.add_argument(
     "--kd-loss-weight",
     type=float,
-    default=0.5,
+    default=None,
     help="KD loss weight (0=pure hard loss, 1=pure KD loss)",
 )
 train_parser.add_argument(
     "--exclude-blocks",
     type=int,
     nargs="*",
-    default=[0, 1, 46, 47],
+    default=None,
     help="Transformer block indices to exclude from quantization",
 )
@@
-calib_size = args.calib_size if args.calib_size != 512 else qad_config.get("calib_size", 512)
+calib_size = args.calib_size if args.calib_size is not None else qad_config.get("calib_size", 512)
 kd_loss_weight = (
-    args.kd_loss_weight if args.kd_loss_weight != 0.5 else qad_config.get("kd_loss_weight", 0.5)
+    args.kd_loss_weight if args.kd_loss_weight is not None else qad_config.get("kd_loss_weight", 0.5)
 )
 exclude_blocks = (
-    args.exclude_blocks
-    if args.exclude_blocks != [0, 1, 46, 47]
-    else qad_config.get("exclude_blocks", [0, 1, 46, 47])
+    args.exclude_blocks if args.exclude_blocks is not None else qad_config.get("exclude_blocks", [0, 1, 46, 47])
 )
```

Also applies to: 953-961

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py` around lines 845-861: the CLI currently treats arguments equal to their default as "not provided", so YAML values win. Change the parser so explicitly-passed defaults are preserved by constructing the ArgumentParser with argument_default=argparse.SUPPRESS (or otherwise suppressing defaults) and then checking for presence in vars(args) when merging CLI→YAML. Update the parser that defines train_parser (and any other parser instances that add --calib-size, --kd-loss-weight, --exclude-blocks, and the similar args around lines 953-961) so the add_argument calls remain the same but the parser is created with argument_default=argparse.SUPPRESS, and the merge logic uses "if 'calib_size' in vars(args)" (and analogous keys for kd_loss_weight and exclude_blocks) to respect explicit CLI values even when they equal the default.
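The `argparse.SUPPRESS` route from the prompt above keeps the `add_argument` calls unchanged and instead tests membership in `vars(args)`; a minimal standalone sketch (the `merge` helper is hypothetical):

```python
import argparse

# With argument_default=SUPPRESS, omitted options simply don't appear in the namespace.
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
parser.add_argument("--calib-size", type=int)


def merge(args: argparse.Namespace, yaml_cfg: dict, hard_default: int = 512) -> int:
    """Explicit CLI value wins (even if it equals a default), then YAML, then hard default."""
    provided = vars(args)
    if "calib_size" in provided:
        return provided["calib_size"]
    return yaml_cfg.get("calib_size", hard_default)


yaml_cfg = {"calib_size": 256}
print(merge(parser.parse_args(["--calib-size", "512"]), yaml_cfg))  # 512 (CLI wins)
print(merge(parser.parse_args([]), yaml_cfg))                       # 256 (YAML wins)
print(merge(parser.parse_args([]), {}))                             # 512 (hard default)
```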
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/windows/torch_onnx/diffusers/qad_example/README.md`:
- Around line 66-70: The README has a filename mismatch: the prose references
preprocess_dataset.py while the example command runs scripts/process_dataset.py;
make them consistent by updating either the descriptive text or the command so
both refer to the same script name (choose one: preprocess_dataset.py or
scripts/process_dataset.py) and ensure the example invocation (the bash block)
uses that exact filename along with the shown arguments (e.g.,
--resolution-buckets 384x256x97), so copy-paste users won’t fail.
- Line 3: Update the note sentence that reads "This is a sample script...
verified to run on a Linux RTX 5090 system, but runs into OOM" to explicitly
state support status: indicate that the example can run for small-scale
inference or demo on Linux RTX 5090 but is not supported for full training due
to OOM (i.e., full training is unsupported on a single RTX 5090), and add a
separate short clause calling out the Windows text-encoder load failure as a
known limitation; ensure the revised README message clearly differentiates
"works for demo/inference" vs "not suitable for full training" and references
the Windows text-encoder issue.
In
`@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`:
- Around line 386-392: The calibration loop can crash from an unhandled
StopIteration because DataLoader with drop_last=True may yield zero batches;
update the logic around calib_loader, calib_steps and data_iter to (1) compute
calib_steps from the loader itself (e.g., use len(calib_loader) or
floor(len(dataset)/batch_size when drop_last=True) instead of len(dataset) so
steps match actual batches) and (2) robustly handle StopIteration by checking if
len(calib_loader) == 0 and skipping calibration, or by wrapping both attempts to
call next(data_iter) (the code using data_iter and next(data_iter)) in a
try-except that recreates the iterator and if still raises StopIteration, exits
the calibration loop gracefully rather than assuming a second next will succeed.
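The guarded loop this prompt describes can be sketched without torch — plain Python sequences stand in for the DataLoader, and `run_calibration` is a hypothetical reduction of the calibration loop:

```python
def run_calibration(calib_loader, calib_steps, calibrate_fn):
    """Consume up to calib_steps batches, recreating the iterator at most once
    and exiting gracefully instead of crashing on StopIteration."""
    if len(calib_loader) == 0:  # drop_last=True can leave zero batches
        return 0
    steps = min(calib_steps, len(calib_loader))  # derive steps from the loader itself
    data_iter = iter(calib_loader)
    done = 0
    for _ in range(steps):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(calib_loader)  # recreate once...
            try:
                batch = next(data_iter)
            except StopIteration:
                break  # ...and bail out if the loader is still empty
        calibrate_fn(batch)
        done += 1
    return done


seen = []
print(run_calibration([1, 2, 3], calib_steps=10, calibrate_fn=seen.append))  # 3
print(run_calibration([], calib_steps=10, calibrate_fn=seen.append))         # 0
```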
---
Duplicate comments:
In
`@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`:
- Around line 139-143: The code currently calls torch.load(path,
map_location="cpu", weights_only=False) which allows unsafe pickle
deserialization; change the call to torch.load(path, map_location="cpu",
weights_only=True) to avoid untrusted pickle deserialization when loading
CLI-provided checkpoints (keep the existing handling of the returned object and
"state_dict" key), and add a short fallback/error path if weights_only=True
fails for an older checkpoint format so we fail safe rather than re-enabling
pickle.
- Around line 845-861: The CLI currently treats arguments equal to their default
as "not provided" so YAML values win; change the parser so explicitly-passed
defaults are preserved by constructing the ArgumentParser with
argument_default=argparse.SUPPRESS (or otherwise suppressing defaults) and then
checking for presence in vars(args) when merging CLI→YAML; update the parser
that defines train_parser (and any other parser instances that add --calib-size,
--kd-loss-weight, --exclude-blocks and the similar args around lines 953-961) so
add_argument calls remain the same but the parser is created with
argument_default=argparse.SUPPRESS and the merge logic uses "if 'calib_size' in
vars(args)" (and analogous keys for kd_loss_weight and exclude_blocks) to
respect explicit CLI values even when they equal the default.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 11b1053f-f84a-482c-853d-306709d229d2
📒 Files selected for processing (5)
- examples/windows/torch_onnx/diffusers/qad_example/README.md
- examples/windows/torch_onnx/diffusers/qad_example/fsdp_custom.yaml
- examples/windows/torch_onnx/diffusers/qad_example/ltx2_qad.yaml
- examples/windows/torch_onnx/diffusers/qad_example/requirements.txt
- examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
🚧 Files skipped from review as they are similar to previous changes (3)
- examples/windows/torch_onnx/diffusers/qad_example/ltx2_qad.yaml
- examples/windows/torch_onnx/diffusers/qad_example/requirements.txt
- examples/windows/torch_onnx/diffusers/qad_example/fsdp_custom.yaml
@@ -0,0 +1,162 @@

```markdown
# LTX-2 QAD Example (Quantization-Aware Distillation)

**Note:** This is a **sample script for illustrating the QAD pipeline**. It has been verified to run on a **Linux RTX 5090** system, but runs into **OOM (Out of Memory)** on that configuration.
```
Clarify the single-RTX support statement.
“Verified to run on Linux RTX 5090” conflicts with “runs into OOM” in the same sentence. Please state explicitly whether this setup is unsupported for full training (and note the Windows text-encoder load failure as a known limitation).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/windows/torch_onnx/diffusers/qad_example/README.md` at line 3,
Update the note sentence that reads "This is a sample script... verified to run
on a Linux RTX 5090 system, but runs into OOM" to explicitly state support
status: indicate that the example can run for small-scale inference or demo on
Linux RTX 5090 but is not supported for full training due to OOM (i.e., full
training is unsupported on a single RTX 5090), and add a separate short clause
calling out the Windows text-encoder load failure as a known limitation; ensure
the revised README message clearly differentiates "works for demo/inference" vs
"not suitable for full training" and references the Windows text-encoder issue.
> Run the LTX preprocessing script to extract latents and text embeddings from your videos. Use `preprocess_dataset.py` with the following arguments (matching the LTX training pipeline):
>
> ```bash
> python scripts/process_dataset.py /path/to/dataset.json \
>   --resolution-buckets 384x256x97 \
> ```
Dataset preprocessing command name is inconsistent.
The prose says `preprocess_dataset.py`, but the command runs `scripts/process_dataset.py`. Please align the two names to avoid copy-paste failures.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/windows/torch_onnx/diffusers/qad_example/README.md` around lines 66
- 70, The README has a filename mismatch: the prose references
preprocess_dataset.py while the example command runs scripts/process_dataset.py;
make them consistent by updating either the descriptive text or the command so
both refer to the same script name (choose one: preprocess_dataset.py or
scripts/process_dataset.py) and ensure the example invocation (the bash block)
uses that exact filename along with the shown arguments (e.g.,
--resolution-buckets 384x256x97), so copy-paste users won’t fail.
> ```python
> calib_loader = DataLoader(
>     dataset,
>     batch_size=self._config.optimization.batch_size,
>     shuffle=False,
>     num_workers=0,
>     drop_last=True,
> )
> ```
🧩 Analysis chain

🏁 Script executed:

```shell
wc -l examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 152

🏁 Script executed:

```shell
sed -n '380,425p' examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1786
Uncaught StopIteration exception crashes calibration loop when dataset size is smaller than batch size.
The calibration loop's exception handling at lines 406–408 has a critical flaw: when `drop_last=True` and `len(dataset) < batch_size`, the `DataLoader` produces zero batches. The first `next(data_iter)` raises `StopIteration`; the `except` block catches it and recreates the iterator, but the second `next(data_iter)` on line 408 is not wrapped in a try/except and raises `StopIteration` again, crashing calibration without a handler.
Additionally, line 402 uses `len(dataset)` to compute `calib_steps`, but with `drop_last=True` the actual number of batches may differ from the dataset size.
Suggested fix

```diff
 calib_loader = DataLoader(
     dataset,
     batch_size=self._config.optimization.batch_size,
     shuffle=False,
     num_workers=0,
-    drop_last=True,
+    drop_last=False,
 )
+if len(calib_loader) == 0:
+    logger.error("Calibration loader is empty. Check dataset size and batch_size.")
+    return
-
-calib_steps = min(self._calib_size, len(dataset))
+calib_steps = min(self._calib_size, len(calib_loader))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`
around lines 386 - 392, The calibration loop can crash from an unhandled
StopIteration because DataLoader with drop_last=True may yield zero batches;
update the logic around calib_loader, calib_steps and data_iter to (1) compute
calib_steps from the loader itself (e.g., use len(calib_loader) or
floor(len(dataset)/batch_size when drop_last=True) instead of len(dataset) so
steps match actual batches) and (2) robustly handle StopIteration by checking if
len(calib_loader) == 0 and skipping calibration, or by wrapping both attempts to
call next(data_iter) (the code using data_iter and next(data_iter)) in a
try-except that recreates the iterator and if still raises StopIteration, exits
the calibration loop gracefully rather than assuming a second next will succeed.
What does this PR do?

Sample QAD example script.

**Type of change:** New example

Example script for QAD on diffusion models such as LTX-2.

**Overview:** ?

Usage

Testing

Before your PR is "Ready for review"

Additional Information

Summary by CodeRabbit

- **Documentation**
- **New Features**