Commit 229ba61
authored
[OMNIML-3050] Enable torch.compile on _get_log_softmax_dist (#1479)
### What does this PR do?
Type of change: new feature
Wraps `_get_log_softmax_dist`
(`modelopt/torch/quantization/algorithms.py`) — the distributed
log-softmax helper used by `AutoQuantizeKLDivSearcher` under TP > 1 —
with `@torch.compile(dynamic=True)`. Inductor fuses the `amax /
all_reduce(MAX) / logsumexp / all_reduce(SUM) / log+sub+cast` pipeline
into one kernel. `dynamic=True` avoids recompiles across the varying
`[batch, seq]` shapes seen during calibration, matching the existing
pattern in `backends/fp8_per_tensor_gemm.py`. Stale TODOs are removed;
the prior ONNX-Windows concern no longer applies because the function is
only reachable when a TP group is initialized (never on the Windows CPU
unit-test job), and `@torch.compile` is import-time safe.
### Usage
```python
# Internal — invoked automatically by AutoQuantize KL-divergence search under TP > 1:
import modelopt.torch.quantization as mtq
model, _ = mtq.auto_quantize(
model,
constraints={"effective_bits": 6.0},
quantization_formats=[mtq.INT4_AWQ_CFG, mtq.INT8_DEFAULT_CFG],
data_loader=calib_loader,
forward_step=lambda m, b: m(b),
method="kl_div",
)
```
### Testing
Verified locally against the affected unit-test paths:
- `pytest tests/unit/torch/quantization/test_autoquant.py -k kl_div` →
22 passed (covers the call chain into `_get_log_prob`).
- `pytest tests/unit/onnx/` → 516 passed. The single failure
(`test_autocast_quantize.py::test_autocast_quantize_int8[False-True]` —
onnxruntime `CopyTensorAsync is not implemented`) is a pre-existing
failure on `main`, unrelated to this change (confirmed by stashing the
diff and re-running).
- `pytest tests/unit/torch/quantization/test_autoquant.py
tests/unit/torch/quantization/test_quantize_cpu.py
tests/unit/torch/quantization/test_config_validation.py` → 159 passed.
- Functional sanity in a single-process gloo group: compiled output
matches `torch.log_softmax` reference for fp32 and fp16, with no
recompiles across varying `[batch, seq]` shapes.
- `pre-commit run --files modelopt/torch/quantization/algorithms.py` →
ruff / mypy / bandit all pass.
### Before your PR is "*Ready for review*"
Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (\`git commit -s -S\`).
Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded \`trust_remote_code=True\`, \`torch.load(...,
weights_only=False)\`, \`pickle\`, etc.).
- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in \`CONTRIBUTING.md\`: N/A
- Did you write any new necessary tests?: N/A — existing \`kl_div\`
autoquant tests already exercise the call chain.
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
N/A
- Did you get Claude approval on this PR?: ❌ — will run \`/claude
review\` after marking ready.
### Additional Information
Single-line behavior change: \`_get_log_softmax_dist\` is now
\`@torch.compile(dynamic=True)\`-decorated. No API surface changes.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>1 parent 94337ad commit 229ba61
1 file changed
Lines changed: 1 addition & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1094 | 1094 | | |
1095 | 1095 | | |
1096 | 1096 | | |
1097 | | - | |
1098 | | - | |
| 1097 | + | |
1099 | 1098 | | |
1100 | | - | |
1101 | 1099 | | |
1102 | 1100 | | |
1103 | 1101 | | |
| |||
0 commit comments