
Auto Quantize improvements and bug fixes for large sparse MoEs #953

Merged
realAsma merged 4 commits into main from asma/nemotron_mixed
Mar 10, 2026

Conversation


@realAsma realAsma commented Mar 2, 2026

What does this PR do?

Type of change: New feature + Bug fixes

Overview:

Enable AutoQuantize for NemotronH and large SparseMoE models, and update the FP8 workflow split between mtq.auto_quantize and mtq.quantize.

mtq.auto_quantize is now positioned as the lightweight search phase (lite calibration + scoring), while mtq.quantize is used for heavier/final calibration workflows (longer calibration passes, force-all-token style MoE calibration, and advanced recipes such as GPTQ, MSE, etc.).

Algorithm & feature changes

  • NemotronH / SparseMoE support: Updated quant_module and score_module rules (these should eventually move to the proposed modeling lib). In the future, this should be the only change needed to support new models — the bug fixes below were unearthed while enabling NemotronH.
  • Config generation: Added mtq.get_auto_quantize_config(search_state, constraints=None, verbose=False) to re-solve from search_state and produce plain-dict configs (no redundant output_quantizer), with optional verbose summary
  • FP8 workflow split: Use lite calibration in mtq.auto_quantize, then run longer/final calibration with mtq.quantize using the generated config
  • Performance: Pass name_to_module to enable_weight_access_and_writeback to avoid O(N^2) overhead on large MoE models
  • Calibration caching in checkpoint: Save/restore quantizer calibration states (metadata + state_dict) per recipe in the AutoQuantize checkpoint, so resuming a search skips redundant calibration
  • Per-rank distributed checkpointing: When torch.distributed is initialized, each rank saves/loads its own checkpoint file (search_state{rank}.pt), with backward-compatible fallback to the single-file path

API updates

  • Config API naming: Use mtq.get_auto_quantize_config(...) for exporting the searched recipe into a quantize-ready config
  • Recommended usage pattern:
# 1) Lightweight search + lite calibration
model, search_state = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 6.0},
    quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG],
    data_loader=data_loader,
    forward_step=forward_step,
    loss_func=loss_func,
    num_calib_steps=64,   # lite calibration during search
    num_score_steps=128,
)

# 2) Export searched config (optionally re-solve constraints)
auto_quantize_config = mtq.get_auto_quantize_config(
    search_state,
    constraints={"effective_bits": 6.0},
    verbose=True,
)

# 3) Final / longer calibration pass with quantize
model = mtq.quantize(
    model,
    config=auto_quantize_config,
    forward_loop=long_calibration_loop,  # e.g. force-all-token style MoE calibration
)

Bug fixes

  • Fixed disabled_layers handling so fused kernels (e.g. Mamba blocks) are properly skipped
  • Fixed gradient checkpointing to keep all modules except the checkpointed modules in eval
  • Fixed FP8 fake quant NaN/inf when amax ≈ 0
  • Fixed SequentialQuantizer.convert_to_single_quantizer to operate on module instead of model, avoiding O(N^2) CPU iteration on SparseMoE models with 1000s of submodules
  • Switched to proper F.kl_div for KL divergence scoring
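The FP8 amax guard can be illustrated with a plain-Python sketch. The helper name is hypothetical; the real fix lives in tensor_quant.py and the CUDA kernel, both built around the E4M3 max finite value of 448:

```python
def fp8_fake_quant(x: float, amax: float, eps: float = 1e-12) -> float:
    """Hypothetical sketch of the FP8 (E4M3) fake-quant zero-amax guard.

    When amax is ~0, substitute 1.0 so scale = 448 / amax stays finite,
    then clamp the scaled value to the E4M3 range before the round-trip.
    """
    safe_amax = amax if amax > eps else 1.0  # guard against amax ~ 0
    scale = 448.0 / safe_amax
    # Clamp mirrors the eager path's .clamp(min=-448.0, max=448.0).
    scaled = max(-448.0, min(448.0, x * scale))
    return scaled / scale
```

With amax = 0 the guard keeps the output finite instead of producing NaN/inf from a division by zero.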

Not yet exposed to llm_ptq

mtq.get_auto_quantize_config is not yet wired into llm_ptq. The plain config records per-expert quantization settings for all MoE experts, resulting in large JSON files. For my experiments I used a quick workaround. A follow-up PR will add a better config representation and expose it to llm_ptq.

Testing

  • Tested on NemotronH-tiny and Nemotron-Super-RL models
  • Verified auto_quantize scoring + config generation end-to-end
  • Unit test for checkpoint resume verifies calibration cache correctness (metadata + tensor values)
  • Existing unit tests pass

Before your PR is "Ready for review"

  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: AutoQuantize end-to-end requires GPU + large MoE models; verified manually on NemotronH-tiny and Nemotron-Super-RL. Unit test coverage to follow.
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: Yes

Additional Information

Follow-up planned: expose mtq.get_auto_quantize_config to llm_ptq with a compact config format for MoE models. AWQ support in AutoQuantize can also be removed in a future PR to keep it lightweight.

Summary by CodeRabbit

  • New Features

    • Added get_auto_quantize_config() API to extract and reapply quantization configs; auto-quantize now persists and restores per-module calibration state and method.
    • Exposed config re-apply path and added validation for supported auto-quantize algorithms.
  • Bug Fixes

    • FP8 scaling now guards zero amax and clamps pre-conversion values to avoid NaNs/Infs.
  • Compatibility

    • Utility signature updated to accept a precomputed module map for performance.
  • Tests

    • Added tests for zero-amax handling, idempotent quantization, and auto-quantize config/resume behavior.

@realAsma realAsma requested a review from a team as a code owner March 2, 2026 21:15
@realAsma realAsma requested a review from meenchen March 2, 2026 21:15

coderabbitai Bot commented Mar 2, 2026

📝 Walkthrough

Walkthrough

Adds auto-quantize method tracking and per-quantizer checkpoint persistence, provides config reconstruction (get_auto_quantize_config), centralizes name_to_module propagation in calibration, guards FP8 scaling against zero amax, adjusts gradient-checkpointing mode handling, and expands related tests.

Changes

Cohort / File(s) Summary
Auto-quantize search & helpers
modelopt/torch/quantization/algorithms.py
Track search method_name, persist per-quantizer quantizer_states in checkpoints, add looser checkpoint load, pass quant_module_names through recipe hparams/candidate stats, refactor KL-divergence log-prob paths, and add helpers: get_auto_quantize_config, _resolve_best_recipe, _match_quantizer_cfg, _print_recipe_summary.
Model quantization API
modelopt/torch/quantization/model_quant.py
Expose public get_auto_quantize_config wrapper, validate supported auto-quantize algorithms, apply config when model already quantized, add _AUTO_QUANTIZE_SUPPORTED_ALGORITHMS, update __all__.
Calibration mapping & utils
modelopt/torch/quantization/model_calib.py, modelopt/torch/quantization/utils.py
Introduce and propagate precomputed name_to_module dict through calibration flows; update calls to enable_weight_access_and_writeback(module, model, name_to_module) and document mapping to avoid repeated named lookups.
FP8 scaling safety
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu, modelopt/torch/quantization/tensor_quant.py
Add epsilon-based near-zero amax detection and safe_amax substitution (use 1 for tiny amax), compute scale from safe_amax, and clamp pre-FP8 values to [-448, 448] to avoid NaN/Inf.
Gradient-checkpointing mode handling
modelopt/torch/quantization/plugins/huggingface.py
Replace per-module Dropout toggles with per-module mode: modules exposing gradient_checkpointing set to train(), others set to eval(); keep gradient checkpointing enabled.
Searcher checkpointing & paths
modelopt/torch/opt/searcher.py
Add per-rank checkpoint resolver _get_checkpoint_path, make load_search_checkpoint(strict=True) support non-strict/partial-key loading, and improve save/load resilience using default_state_dict defaults.
Tests
tests/gpu/torch/quantization/test_tensor_quant_cuda.py, tests/unit/torch/quantization/test_autoquant.py, tests/unit/torch/quantization/test_quantize_cpu.py
Add FP8 zero-amax tests (uniform and per-channel), tests for selective-layer auto-quantize & get_auto_quantize_config application, extend checkpoint resume assertions (method and quantizer_states), and add idempotence test for repeated quantize.
Changelog
CHANGELOG.rst
Document get_auto_quantize_config API, improved checkpoint/restore for calibration state, and MoE grouping support notes.

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller
    participant Searcher as AutoQuantizeSearcher
    participant Resolver as RecipeResolver
    participant QS as QuantizerStates
    participant Model as Model

    Caller->>Searcher: get_auto_quantize_config(search_state, constraints)
    Searcher->>Resolver: _resolve_best_recipe(search_state, constraints)
    Resolver->>Searcher: run_search_with_stats(...) -> best_recipe, stats
    alt checkpoint contains quantizer states
        Resolver->>QS: restore/collect per-quantizer state
        QS-->>Searcher: quantizer_states
    end
    Searcher->>Model: set_quantizer_by_cfg(best_recipe.quant_cfg, name_to_module)
    Model-->>Caller: quantized model / flat quant_cfg

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.94% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Auto Quantize improvements and bug fixes for large sparse MoEs' accurately summarizes the main changes in this comprehensive PR that adds new features (get_auto_quantize_config API, per-rank distributed checkpointing, calibration state caching) and bug fixes (FP8 NaN/inf handling, disabled_layers support, O(N^2) overhead reduction) for Auto Quantize, particularly targeting NemotronH and SparseMoE models.
Security Anti-Patterns ✅ Passed The PR has been reviewed against all six security anti-patterns in SECURITY.md. New torch.load calls include proper security comments for internally-generated checkpoints, and no instances of hardcoded trust_remote_code=True, numpy.load with allow_pickle, eval/exec on user input, nosec comments, or undeclared non-permissive dependencies were found.



Comment thread modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/algorithms.py`:
- Around line 1314-1320: The re-solve branch silently falls back to
AutoQuantizeKLDivSearcher when search_state["method"] is anything other than
"gradient", which can mask malformed/old states; update the logic around
search_state["method"] so you explicitly handle valid values (e.g., "gradient"
-> AutoQuantizeGradientSearcher, "kldiv" or a documented name ->
AutoQuantizeKLDivSearcher) and otherwise raise a clear error or log+raise
indicating an unknown/missing method from the stored search_state; keep
assigning searcher.candidate_stats = candidate_stats after constructing the
validated searcher (refer to symbols search_state, method,
AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, and candidate_stats).
- Around line 596-597: The code unconditionally sets self.method =
self.method_name after restoring a checkpoint, which can hide a mismatch between
the checkpoint's scoring method and the current method and cause stale
candidate_stats to be reused; update the restore logic in the class (where
self.method, self.method_name and candidate_stats are handled) to compare the
restored method value from the checkpoint with self.method_name and if they
differ, clear or invalidate candidate_stats (and any cached scoring state) and
log or mark that a re-score is required before accepting any restored stats;
ensure the comparison uses the exact symbol names self.method, self.method_name
and candidate_stats so the guard runs immediately after checkpoint load and
before any code that would skip scoring.

In `@modelopt/torch/quantization/model_quant.py`:
- Around line 234-241: The current guard using is_quantized(model) prevents
re-applying a new QuantizeConfig to an already-quantized model, so remove or
bypass that gate and always call apply_mode with the new config_dict (using
QuantizeConfig(**config_dict) when needed) so the model receives the updated
quantization configuration; specifically ensure the flow around QuantizeConfig,
config_dict, is_quantized, apply_mode(..., mode=[("quantize", config_dict)],
registry=QuantizeModeRegistry) and subsequent calibrate(...) always executes
apply_mode to install the new quant_cfg even if is_quantized(model) returns
True.

In `@modelopt/torch/quantization/plugins/huggingface.py`:
- Around line 1199-1207: The loop is forcing recursive mode changes and ignores
the gradient_checkpointing value; replace calls to m.train()/m.eval() with a
non-recursive assignment of m.training = True when getattr(m,
"gradient_checkpointing", False) is truthy, otherwise m.training = False, so use
getattr(m, "gradient_checkpointing", False) to check the flag value and set the
module's training boolean directly (avoid m.train()/m.eval() calls) to prevent
recursive propagation over child modules.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL


📥 Commits

Reviewing files that changed from the base of the PR and between ba29ad7 and bf70a7a0d0cb97da7f105f26c55630a9455b8b23.

📒 Files selected for processing (9)
  • modelopt/torch/quantization/algorithms.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/model_quant.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
  • modelopt/torch/quantization/tensor_quant.py
  • modelopt/torch/quantization/utils.py
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py
  • tests/unit/torch/quantization/test_autoquant.py

Comment thread modelopt/torch/quantization/algorithms.py Outdated
Comment thread modelopt/torch/quantization/algorithms.py
Comment thread modelopt/torch/quantization/model_quant.py Outdated
Comment thread modelopt/torch/quantization/plugins/huggingface.py

codecov Bot commented Mar 2, 2026

Codecov Report

❌ Patch coverage is 89.86486% with 15 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.22%. Comparing base (fff65b0) to head (84589b8).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/tensor_quant.py 0.00% 6 Missing ⚠️
modelopt/torch/opt/searcher.py 77.27% 5 Missing ⚠️
modelopt/torch/quantization/algorithms.py 95.78% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #953      +/-   ##
==========================================
+ Coverage   70.11%   70.22%   +0.11%     
==========================================
  Files         220      220              
  Lines       25240    25342     +102     
==========================================
+ Hits        17697    17797     +100     
- Misses       7543     7545       +2     


@realAsma realAsma force-pushed the asma/nemotron_mixed branch 4 times, most recently from 46b685d to 558c17c on March 4, 2026 17:33
@realAsma realAsma requested a review from a team as a code owner March 4, 2026 17:33
@realAsma realAsma requested a review from ChenhanYu March 4, 2026 17:33
@realAsma realAsma force-pushed the asma/nemotron_mixed branch from 558c17c to 2729ed6 on March 6, 2026 19:11
@cjluo-nv cjluo-nv requested a review from Copilot March 9, 2026 16:48

Copilot AI left a comment


Pull request overview

This PR enhances the ModelOpt Torch quantization “AutoQuantize” workflow to better support large sparse MoE models (e.g., NemotronH), splitting “search” vs “final quantization” responsibilities and improving performance/resume behavior for large models.

Changes:

  • Added mtq.get_auto_quantize_config(search_state, ...) to export a quantize-ready config (optionally re-solving constraints) from an AutoQuantize search state.
  • Improved AutoQuantize scalability/resume behavior (calibration-state caching in checkpoints; per-rank checkpoint files under torch.distributed; reduced repeated module scans via name_to_module plumbing).
  • Fixed/adjusted several quantization behaviors: disabled-layer handling, FP8 zero/near-zero amax stability, KL-div scoring implementation, and “quantize twice” behavior.
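As a rough illustration of the F.kl_div-based scoring mentioned above — tensor shapes and the log_target choice are illustrative assumptions, not the PR's exact code:

```python
import torch
import torch.nn.functional as F

# Score divergence between reference (unquantized) logits and the quantized
# model's logits. F.kl_div expects log-probabilities as input; with
# log_target=True the target is also given in log space.
torch.manual_seed(0)
ref_logits = torch.randn(4, 10)
quant_logits = ref_logits + 0.01 * torch.randn(4, 10)

score = F.kl_div(
    F.log_softmax(quant_logits, dim=-1),  # input: quantized log-probs
    F.log_softmax(ref_logits, dim=-1),    # target: reference log-probs
    reduction="batchmean",
    log_target=True,
)
```

Using F.kl_div on log-probabilities avoids hand-rolling softmax/log/sum compositions, which tends to be both more numerically stable and less error-prone.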

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
tests/unit/torch/quantization/test_quantize_cpu.py Adds coverage for calling mtq.quantize twice on the same model.
tests/unit/torch/quantization/test_autoquant.py Adds tests for disabled-layer behavior, checkpoint resume (including calibration cache), and config export via get_auto_quantize_config.
tests/gpu/torch/quantization/test_tensor_quant_cuda.py Adds FP8 fake-quant tests to ensure finite outputs when amax is zero (scalar and per-channel).
modelopt/torch/quantization/utils.py Documents name_to_module usage for performance when enabling weight access/writeback.
modelopt/torch/quantization/tensor_quant.py Makes eager FP8 fake quant robust to tiny/zero amax and clamps scaled values for stability.
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu Prevents FP8 scale blow-ups by guarding against tiny/zero amax in CUDA FP8 fake quant.
modelopt/torch/quantization/plugins/huggingface.py Adjusts gradient-checkpointing setup to keep non-checkpointed modules in eval (to avoid fused paths that bypass quantized modules).
modelopt/torch/quantization/model_quant.py Updates quantize() to support re-quantizing already-quantized models; exports get_auto_quantize_config; restricts auto_quantize to supported algorithms.
modelopt/torch/quantization/model_calib.py Threads name_to_module into weight access/writeback contexts to reduce repeated named-module walks.
modelopt/torch/quantization/algorithms.py Adds calibration caching into AutoQuantize checkpoints, fixes KL-div scoring to use F.kl_div, supports NemotronH/MoE grouping rules, and implements get_auto_quantize_config(...).
modelopt/torch/opt/searcher.py Implements per-rank checkpoint paths when distributed; adds strict/non-strict checkpoint loading for backward compatibility.


Comment thread modelopt/torch/quantization/algorithms.py Outdated
Comment thread tests/unit/torch/quantization/test_quantize_cpu.py Outdated
@realAsma realAsma force-pushed the asma/nemotron_mixed branch 2 times, most recently from 6f24db0 to ae1b26b on March 9, 2026 20:14
Comment thread modelopt/torch/quantization/algorithms.py

@cjluo-nv cjluo-nv left a comment


Code Review: PR #953

Good PR overall — the KL divergence fix, the disabled_layers poisoning bug fix, and the O(N²) convert_to_single_quantizer fix are all solid. A few design and readability concerns below.

Design

1. _cfg_to_dict includes both explicit and model_dump(exclude_defaults=True) fields — potential duplication

In get_auto_quantize_config, _cfg_to_dict does:

return {
    "enable": v.enable,
    "num_bits": v.num_bits,
    **v.model_dump(exclude_defaults=True),
}

If enable or num_bits are non-default, they'll appear in both the explicit dict and model_dump(). This works because the ** spread overwrites, but it's fragile — if model_dump ever changes key names, you'd get stale keys. Consider just using v.model_dump() (without exclude_defaults) or explicitly listing only the needed fields.

2. _match_quantizer_cfg uses last-match-wins semantics silently

def _match_quantizer_cfg(quant_cfg, quantizer_attr):
    matched = None
    for pattern, cfg in quant_cfg.items():
        if fnmatch.fnmatch(quantizer_attr, pattern):
            matched = cfg
    return matched

This iterates all patterns and returns the last match. This is fine if it mirrors set_quantizer_by_cfg behavior, but worth a brief comment confirming that last-match-wins is intentional.

3. restore_quantizer_state called with dummy QuantizeConfig() — API smell

restore_quantizer_state(self.model, QuantizeConfig(), {"quantizer_state": saved["metadata"]})

The comment says "config is unused", and same for update_quantize_metadata. If the config parameter is truly unused, this is a sign that these functions' APIs should be refactored to not require it. Not blocking, but worth tracking.

4. Per-rank checkpoint: silently falls back to single-file checkpoint

In load_search_checkpoint, if the per-rank file doesn't exist, it falls back to the original single-file path. This could mask bugs where a user accidentally changes world size. The warning message mentions maintaining parallelism config, but only on save — consider also warning on the fallback path.

Readability

5. method_name class attribute typed as str = None is misleading

class _AutoQuantizeBaseSearcher(...):
    method_name: str = None

This type annotation says str but the default is None. Should be str | None = None or Optional[str] = None for clarity.

6. before_search() is getting long and doing many things

The before_search method now handles: checkpoint method validation, recipe iteration, calibration state restore, calibration execution, calibration state save, and score estimation. Consider extracting the calibration save/restore logic into a _calibrate_or_restore(recipe) helper to improve readability.

Potential Bugs

7. FP8 zero-amax guard: clamp only added in Python eager path, not in CUDA kernel

In tensor_quant.py, the Python eager path adds .clamp(min=-448.0, max=448.0) after scaling:

x = (x.to(torch.float32) * scale).clamp(min=-448.0, max=448.0)

But the CUDA kernels (fake_e4m3fy, fake_e4m3fy_with_axis) only guard the amax→scale computation and don't clamp the scaled values. When zero_amax_mask is True, safe_amax = 1.0 so scale = 448.0, and x * 448.0 could exceed FP8 range for inputs > 1.0. The CUDA kernel relies on the FP8 cast to clamp, which should be fine for torch.float8_e4m3fn (which saturates), but worth verifying this is consistent between the two paths.

8. _resolve_best_recipe instantiates a searcher but doesn't set self.model

searcher = AutoQuantizeGradientSearcher()
searcher.candidate_stats = candidate_stats
best_recipe_info, _ = searcher.run_search_with_stats(max_weight_size, verbose=verbose)

This works because run_search_with_stats only uses self.candidate_stats, not self.model. But it's fragile — if run_search_with_stats ever accesses self.model or self.config, this will crash. A class method or static function on the searcher (e.g., AutoQuantizeGradientSearcher.solve(candidate_stats, max_weight_size)) would be safer.

@cjluo-nv

Alternative Design Proposal for PR #953

After reading the full codebase context, here are concrete alternative designs for the main concerns. The goal is to keep the same functionality while improving separation of concerns, testability, and robustness.


1. Extract calibration cache into its own class

The before_search() method is growing because it now manages calibration state save/restore alongside the search logic. This can be cleanly separated:

class _CalibrationCache:
    """Manages save/restore of per-recipe quantizer calibration states."""

    def __init__(self, initial_states: dict | None = None):
        self._states: dict[QuantRecipe, dict] = dict(initial_states or {})

    def has(self, recipe: QuantRecipe) -> bool:
        return recipe in self._states

    def restore(self, model: nn.Module, recipe: QuantRecipe) -> None:
        """Restore calibration state for a recipe onto the model."""
        saved = self._states[recipe]
        restore_quantizer_state(model, QuantizeConfig(), {"quantizer_state": saved["metadata"]})
        set_quantizer_state_dict(model, saved["state_dict"])

    def save(self, model: nn.Module, recipe: QuantRecipe) -> None:
        """Capture current calibration state from the model for a recipe."""
        metadata: dict = {}
        update_quantize_metadata(model, QuantizeConfig(), metadata)
        self._states[recipe] = {
            "metadata": metadata["quantizer_state"],
            "state_dict": get_quantizer_state_dict(model),
        }

    def to_dict(self) -> dict:
        return dict(self._states)

    @classmethod
    def from_dict(cls, d: dict) -> "_CalibrationCache":
        return cls(initial_states=d)

Then before_search() becomes:

def before_search(self):
    super().before_search()
    self._validate_checkpoint_method()

    calib_cache = _CalibrationCache(self.quantizer_states)
    calibrated_new = False

    for recipe in search_recipes:
        if recipe == QuantRecipe(quant_cfg=None):
            continue
        self._activate_recipe(recipe)
        if calib_cache.has(recipe):
            calib_cache.restore(self.model, recipe)
        else:
            self._run_calibration(recipe)
            calib_cache.save(self.model, recipe)
            calibrated_new = True

    self.quantizer_states = calib_cache.to_dict()
    if calibrated_new:
        self.save_search_checkpoint(verbose=self.config["verbose"])
    # ... rest of scoring logic

This makes the calibration caching independently testable and keeps before_search focused on orchestration.


2. Make the solver a standalone function to avoid fragile partial initialization

Instead of instantiating a searcher object without a model:

# Current (fragile):
searcher = AutoQuantizeGradientSearcher()
searcher.candidate_stats = candidate_stats
best_recipe_info, _ = searcher.run_search_with_stats(max_weight_size)

Extract the solver logic into standalone functions that both run_search and _resolve_best_recipe can call:

def _solve_lp(candidate_stats, max_weight_size, verbose=False):
    """Run LP solver on pre-computed candidate stats. Returns (best_recipes, is_satisfied)."""
    for lower_bound in [None, 0.99, 0.90]:
        # ... existing LP logic from AutoQuantizeGradientSearcher.run_search_with_stats ...
        pass
    return best_recipes, is_satisfied


def _solve_threshold(candidate_stats, max_weight_size, verbose=False):
    """Run threshold binary search on pre-computed candidate stats."""
    # ... existing binary search logic from AutoQuantizeKLDivSearcher.run_search_with_stats ...
    pass

Then _resolve_best_recipe becomes a clean dispatch:

_SOLVERS = {"gradient": _solve_lp, "kl_div": _solve_threshold}

def _resolve_best_recipe(search_state, constraints, verbose=False):
    method = search_state["method"]
    solver = _SOLVERS.get(method)
    if solver is None:
        raise ValueError(f"Unknown method: {method!r}")
    # ... compute max_weight_size ...
    best_recipe_info, _ = solver(candidate_stats, max_weight_size, verbose)
    return {name: info["format"] for name, info in best_recipe_info.items()}

And the searcher classes call the same functions from run_search_with_stats. This eliminates the fragile partially-initialized object pattern and makes the solvers independently testable.


3. Simplify _cfg_to_dict

The current implementation has redundancy between explicit fields and model_dump. Using Python 3.9+ dict merge operator makes the intent clear — "always include enable/num_bits even if default, plus any non-default fields":

def _cfg_to_dict(v):
    if isinstance(v, QuantizerAttributeConfig):
        return v.model_dump(exclude_defaults=True) | {"enable": v.enable, "num_bits": v.num_bits}
    if isinstance(v, list):
        return [_cfg_to_dict(c) for c in v]
    return v

Or if compactness isn't a concern, just use v.model_dump() with no exclusions.


4. Per-rank checkpoint: warn on fallback to prevent silent world-size mismatch

def load_search_checkpoint(self, strict=True) -> bool:
    checkpoint = self._get_checkpoint_path()
    if checkpoint is None:
        return False

    if not os.path.exists(checkpoint):
        fallback = self.config["checkpoint"]
        if fallback and fallback != checkpoint and os.path.exists(fallback):
            warn_rank_0(
                f"Per-rank checkpoint {checkpoint} not found, falling back to {fallback}. "
                "This may indicate a world-size change between save and restore."
            )
            checkpoint = fallback
        else:
            warn_rank_0(f"Checkpoint {checkpoint} does not exist\! Initializing from scratch.")
            return False
    # ... rest of loading

Comment thread modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
@realAsma

@cjluo-nv Thanks for the thorough review! Addressed points 2, 4, and 5 in e59d990. Replies to each:

1. _cfg_to_dict duplication — Both produce the same values from the same object. The explicit enable/num_bits keys provide clarity about what's always included, and model_dump(exclude_defaults=True) only adds extra fields like axis. The ** spread is deterministic here.

2. _match_quantizer_cfg last-match-wins — Added a comment. ✅

3. restore_quantizer_state dummy QuantizeConfig — Agreed it's an API smell. These functions were designed for the broader save/restore flow where config is used. Refactoring is out of scope for this PR but tracked.

4. Per-rank checkpoint fallback — Added a warning on the fallback path. ✅

5. method_name type annotation — Fixed to str | None = None. ✅

6. before_search() long — The method is sequential and each section is logically distinct. Extracting a helper for one call site wouldn't meaningfully improve readability.

7. FP8 CUDA vs Python clamp — Good observation. The CUDA path relies on float8_e4m3fn saturation (which clamps to ±448 on cast), while Python uses explicit .clamp(). Both produce the same result — the CUDA kernel's FP8 cast is the clamp.

8. _resolve_best_recipe fragility — run_search_with_stats is a well-defined method that only uses self.candidate_stats. Adding a static method for this one helper is over-engineering for now.

@realAsma realAsma requested review from cjluo-nv and meenchen March 10, 2026 18:15

@cjluo-nv cjluo-nv left a comment


LGTM! All critical and important review comments have been addressed. A few minor observations for future consideration:

  • Typo: config_resoled in test_get_auto_quantize_config should be config_resolved
  • Gradient checkpointing logic in huggingface.py relies on hasattr(m, "gradient_checkpointing") which is an undocumented HF internal — worth a comment noting this dependency
  • Per-rank distributed checkpointing fallback path (when per-rank file is missing) doesn't have test coverage yet

None of these are blocking. The core fixes (FP8 zero-amax, O(N^2) elimination, disabled_layers poisoning, KL divergence refactoring) are solid and well-tested.
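The per-rank file naming mentioned in the review (and described in the PR summary as `search_state{rank}.pt`) can be sketched as a small path helper; the function name and fallback signature here are hypothetical, not the actual API:

```python
import os

def rank_checkpoint_path(base_path, rank=None):
    # Hypothetical helper: derive a per-rank file name like "search_state{rank}.pt";
    # rank=None falls back to the single-file path for backward compatibility.
    if rank is None:
        return base_path
    root, ext = os.path.splitext(base_path)
    return f"{root}{rank}{ext}"

print(rank_checkpoint_path("search_state.pt", rank=3))  # search_state3.pt
print(rank_checkpoint_path("search_state.pt"))          # search_state.pt
```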

Contributor

@meenchen meenchen left a comment


Looks good overall, left some questions

realAsma and others added 2 commits March 10, 2026 19:25
- Add get_auto_quantize_config API to extract quant config from search results
- Save/restore calibration state in auto_quantize checkpoints
- Add NemotronH MoE expert support in auto_quantize grouping/scoring
- Fix SequentialQuantizer scope, use F.kl_div for numerical stability
- Fix mypy errors and clean up tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
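The "use F.kl_div for numerical stability" commit bullet refers to computing KL divergence entirely in log space. A minimal sketch of that pattern (tensor shapes and names are illustrative, not taken from the PR):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p_logits = torch.randn(4, 10)  # reference distribution logits
q_logits = torch.randn(4, 10)  # candidate distribution logits

# F.kl_div expects log-probabilities as its input; passing the target in log
# space as well (log_target=True) keeps the whole computation in the log
# domain, avoiding underflow from exponentiating small probabilities first.
kl = F.kl_div(
    F.log_softmax(q_logits, dim=-1),
    F.log_softmax(p_logits, dim=-1),
    reduction="batchmean",
    log_target=True,
)
```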
@realAsma realAsma force-pushed the asma/nemotron_mixed branch from e59d990 to c2cb8ec on March 10, 2026 19:26
@realAsma realAsma enabled auto-merge (squash) March 10, 2026 19:27
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
modelopt/torch/opt/searcher.py (1)

275-292: Consider making the distributed warning conditional or one-time.

The warning on lines 282-285 is emitted on every save when torch.distributed is initialized. For long-running searches with frequent checkpointing, this could be noisy.

Consider either:

  1. Emitting the warning only once (e.g., using a class-level flag)
  2. Making it conditional on verbose mode

However, this is a minor UX concern and the current behavior ensures users are always reminded of the parallelism constraint.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/opt/searcher.py` around lines 275 - 292, The distributed
warning in save_search_checkpoint is noisy because it's logged on every save
when dist.is_initialized(); modify save_search_checkpoint (or the class) to emit
this warning only once by introducing a flag (e.g., a class-level or instance
attribute like _warned_dist) that you check before calling warn_rank_0 and set
to True after the first warning, or alternatively make the warn_rank_0 call
conditional on the verbose parameter so it only prints when verbose is True;
update the logic around dist.is_initialized() in save_search_checkpoint to use
that flag/condition and initialize the flag on the class or in __init__.
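The warn-once pattern the nitpick suggests can be sketched with a class-level flag; the class and parameter names below are illustrative stand-ins for the real searcher:

```python
import warnings

class Searcher:
    _warned_dist = False  # class-level flag: warn once per process, not per save

    def save_search_checkpoint(self, dist_initialized=True):
        # Only the first save under torch.distributed emits the reminder.
        if dist_initialized and not Searcher._warned_dist:
            warnings.warn("per-rank checkpoints assume an unchanged parallel layout")
            Searcher._warned_dist = True

s = Searcher()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    s.save_search_checkpoint()
    s.save_search_checkpoint()
print(len(caught))  # 1
```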
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/algorithms.py`:
- Around line 390-391: The early return from load_search_checkpoint (which calls
super().load_search_checkpoint(strict=False)) allows legacy candidate_stats
without "module_names" to pass through and later break get_auto_quantize_config
when it expects candidate_stats[hparam_name]["module_names"]; modify
load_search_checkpoint to detect legacy candidate_stats in the resumed
search_state (inspect search_state.candidate_stats or candidate_stats per
hparam_name), backfill a default module_names list (e.g., derived from existing
keys or an empty list) for each candidate_stats entry before returning, and
ensure this backfill logic is applied in the same method where
load_search_checkpoint is defined so subsequent calls to
get_auto_quantize_config see the expected "module_names" field.
- Around line 1302-1311: When flattening per-module quantizer settings, honor a
recipe's "default" entry if no glob matcher matches: after calling
_match_quantizer_cfg(recipe.config.quant_cfg, quantizer_attr) (inside the loop
over best_recipe, module_names and quantizer_attr), if matched_cfg is None then
check recipe.config.quant_cfg.get("default") and use that value as matched_cfg
before assigning quant_cfg[f"{module_name}.{quantizer_attr}"] = matched_cfg; do
this for all occurrences (the shown block and the similar block around lines
~1358-1364).
- Around line 1284-1325: get_auto_quantize_config currently always returns
{"algorithm": "max"} which silently discards per-recipe algorithm choices;
update it to inspect the algorithm/type of each selected recipe in best_recipe
(use the QuantRecipe instances in search_state["best"]["recipe"] or from
_resolve_best_recipe) and if all selected recipes share the same algorithm
string (e.g., "smoothquant", "awq_*", "local_hessian", "max") return that
algorithm value instead of "max"; if recipes across layers use mixed algorithms
that cannot be represented by a single flat QuantizeConfig, raise a clear
error/ValueError indicating mixed algorithms so callers must handle/export a
mixed-recipe representation. Ensure the detection runs before building quant_cfg
and reference get_auto_quantize_config, best_recipe, and QuantRecipe when
implementing the checks.
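The "default" fallback the second inline comment asks for can be sketched as a tiny glob matcher; this is an illustrative reimplementation, not the actual `_match_quantizer_cfg`:

```python
import fnmatch

def match_quantizer_cfg(quant_cfg, quantizer_attr):
    # Last glob match wins (mirroring the documented matcher behavior); if no
    # pattern matches, fall back to the recipe's "default" entry.
    matched = None
    for pattern, cfg in quant_cfg.items():
        if pattern != "default" and fnmatch.fnmatch(quantizer_attr, pattern):
            matched = cfg
    return matched if matched is not None else quant_cfg.get("default")

cfg = {"*weight_quantizer": {"num_bits": 4}, "default": {"enable": False}}
print(match_quantizer_cfg(cfg, "layers.0.weight_quantizer"))  # {'num_bits': 4}
print(match_quantizer_cfg(cfg, "layers.0.input_quantizer"))   # {'enable': False}
```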

---

Nitpick comments:
In `@modelopt/torch/opt/searcher.py`:
- Around line 275-292: The distributed warning in save_search_checkpoint is
noisy because it's logged on every save when dist.is_initialized(); modify
save_search_checkpoint (or the class) to emit this warning only once by
introducing a flag (e.g., a class-level or instance attribute like _warned_dist)
that you check before calling warn_rank_0 and set to True after the first
warning, or alternatively make the warn_rank_0 call conditional on the verbose
parameter so it only prints when verbose is True; update the logic around
dist.is_initialized() in save_search_checkpoint to use that flag/condition and
initialize the flag on the class or in __init__.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 402459bd-3697-4c75-9ac6-304cb03a17dd

📥 Commits

Reviewing files that changed from the base of the PR and between bf70a7a0d0cb97da7f105f26c55630a9455b8b23 and c2cb8ec.

📒 Files selected for processing (12)
  • CHANGELOG.rst
  • modelopt/torch/opt/searcher.py
  • modelopt/torch/quantization/algorithms.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/model_quant.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
  • modelopt/torch/quantization/tensor_quant.py
  • modelopt/torch/quantization/utils.py
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py
  • tests/unit/torch/quantization/test_autoquant.py
  • tests/unit/torch/quantization/test_quantize_cpu.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/huggingface.py

realAsma and others added 2 commits March 10, 2026 20:05
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Collaborator

@kevalmorabia97 kevalmorabia97 left a comment


Approving as codeowner. Only reviewed changes in modelopt/torch/opt

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
modelopt/torch/quantization/algorithms.py (1)

1313-1322: Minor: _cfg_to_dict includes redundant keys.

model_dump(exclude_defaults=True) may already include enable and num_bits when they differ from defaults, making the explicit additions potentially redundant. However, the explicit inclusion ensures determinism and clarity as the author noted in PR comments.

Consider simplifying (optional)

If you want to avoid potential duplicate keys in the output:

     def _cfg_to_dict(v):
         if isinstance(v, mtq_config.QuantizerAttributeConfig):
-            return {
-                "enable": v.enable,
-                "num_bits": v.num_bits,
-                **v.model_dump(exclude_defaults=True),
-            }
+            d = v.model_dump(exclude_defaults=True)
+            # Ensure enable and num_bits are always present
+            d.setdefault("enable", v.enable)
+            d.setdefault("num_bits", v.num_bits)
+            return d
         if isinstance(v, list):
             return [_cfg_to_dict(c) for c in v]
         return v

This is a nitpick - the current approach works correctly and the author explicitly chose this for clarity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/algorithms.py` around lines 1313 - 1322, The
helper _cfg_to_dict currently builds the dict by placing explicit "enable" and
"num_bits" keys before merging model_dump(exclude_defaults=True), which can
produce redundant keys if model_dump also contains them; to make the result
deterministic and avoid duplicate key confusion, change the merge order in
_cfg_to_dict (use {**v.model_dump(exclude_defaults=True), "enable": v.enable,
"num_bits": v.num_bits}) so that the explicit fields override any values from
model_dump, and keep the list handling logic (_cfg_to_dict for lists) unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/quantization/algorithms.py`:
- Around line 1313-1322: The helper _cfg_to_dict currently builds the dict by
placing explicit "enable" and "num_bits" keys before merging
model_dump(exclude_defaults=True), which can produce redundant keys if
model_dump also contains them; to make the result deterministic and avoid
duplicate key confusion, change the merge order in _cfg_to_dict (use
{**v.model_dump(exclude_defaults=True), "enable": v.enable, "num_bits":
v.num_bits}) so that the explicit fields override any values from model_dump,
and keep the list handling logic (_cfg_to_dict for lists) unchanged.
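The merge-order point above comes down to plain dict-spread semantics: the right-hand side of `{**a, **b}` wins on duplicate keys. A standalone demonstration with stand-in values (the `dumped` dict is a placeholder for `model_dump(exclude_defaults=True)` output):

```python
# Two ways to merge model_dump output with explicit keys; placement decides
# which value wins when both sides carry the same key.
dumped = {"enable": False, "axis": 0}       # stand-in for model_dump(...)
explicit = {"enable": True, "num_bits": 8}  # fields we always want present

explicit_wins = {**dumped, **explicit}  # explicit values override the dump
dump_wins = {**explicit, **dumped}      # dump values override the explicit ones

print(explicit_wins)  # {'enable': True, 'axis': 0, 'num_bits': 8}
print(dump_wins)      # {'enable': False, 'num_bits': 8, 'axis': 0}
```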

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 69a18ea4-95a0-4e72-a5e4-f1f960c6d8f0

📥 Commits

Reviewing files that changed from the base of the PR and between c2cb8ec and 67b2dc2.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/algorithms.py

@realAsma realAsma merged commit a5d46ff into main Mar 10, 2026
39 checks passed
@realAsma realAsma deleted the asma/nemotron_mixed branch March 10, 2026 20:59