Skip to content

Commit 229ba61

Browse files
authored
[OMNIML-3050] Enable torch.compile on _get_log_softmax_dist (#1479)
### What does this PR do? Type of change: new feature Wraps `_get_log_softmax_dist` (`modelopt/torch/quantization/algorithms.py`) — the distributed log-softmax helper used by `AutoQuantizeKLDivSearcher` under TP > 1 — with `@torch.compile(dynamic=True)`. Inductor fuses the `amax / all_reduce(MAX) / logsumexp / all_reduce(SUM) / log+sub+cast` pipeline into one kernel. `dynamic=True` avoids recompiles across the varying `[batch, seq]` shapes seen during calibration, matching the existing pattern in `backends/fp8_per_tensor_gemm.py`. Stale TODOs are removed; the prior ONNX-Windows concern no longer applies because the function is only reachable when a TP group is initialized (never on the Windows CPU unit-test job), and `@torch.compile` is import-time safe. ### Usage ```python # Internal — invoked automatically by AutoQuantize KL-divergence search under TP > 1: import modelopt.torch.quantization as mtq model, _ = mtq.auto_quantize( model, constraints={"effective_bits": 6.0}, quantization_formats=[mtq.INT4_AWQ_CFG, mtq.INT8_DEFAULT_CFG], data_loader=calib_loader, forward_step=lambda m, b: m(b), method="kl_div", ) ``` ### Testing Verified locally against the affected unit-test paths: - `pytest tests/unit/torch/quantization/test_autoquant.py -k kl_div` → 22 passed (covers the call chain into `_get_log_prob`). - `pytest tests/unit/onnx/` → 516 passed. The single failure (`test_autocast_quantize.py::test_autocast_quantize_int8[False-True]` — onnxruntime `CopyTensorAsync is not implemented`) is a pre-existing failure on `main`, unrelated to this change (confirmed by stashing the diff and re-running). - `pytest tests/unit/torch/quantization/test_autoquant.py tests/unit/torch/quantization/test_quantize_cpu.py tests/unit/torch/quantization/test_config_validation.py` → 159 passed. - Functional sanity in a single-process gloo group: compiled output matches `torch.log_softmax` reference for fp32 and fp16, with no recompiles across varying `[batch, seq]` shapes. - `pre-commit run --files modelopt/torch/quantization/algorithms.py` → ruff / mypy / bandit all pass. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (\`git commit -s -S\`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded \`trust_remote_code=True\`, \`torch.load(..., weights_only=False)\`, \`pickle\`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in \`CONTRIBUTING.md\`: N/A - Did you write any new necessary tests?: N/A — existing \`kl_div\` autoquant tests already exercise the call chain. - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A - Did you get Claude approval on this PR?: ❌ — will run \`/claude review\` after marking ready. ### Additional Information Single-line behavior change: \`_get_log_softmax_dist\` is now \`@torch.compile(dynamic=True)\`-decorated. No API surface changes. Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 94337ad commit 229ba61

1 file changed

Lines changed: 1 addition & 3 deletions

File tree

modelopt/torch/quantization/algorithms.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,10 +1094,8 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
10941094
return best_recipes, is_satisfied
10951095

10961096

1097-
# TODO: Enable torch compile for this function
1098-
# Currently modelopt.onnx is breaking this
1097+
@torch.compile(dynamic=True)
10991098
def _get_log_softmax_dist(logits: torch.Tensor, tp_group) -> torch.Tensor:
1100-
# TODO: test this
11011099
dtype = logits.dtype
11021100
max_logits = torch.amax(logits, dim=-1, keepdim=True)
11031101
torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=tp_group)

0 commit comments

Comments
 (0)