[JAX] Use const scale in MHA after softmax#2466
Merged
Merged
Conversation
For tensor quantized after const softmax const range of values can be assumed. Signed-off-by: Andrzej Kotłowski <andrzej.kotlowski@intel.com>
Signed-off-by: Andrzej Kotłowski <andrzej.kotlowski@intel.com>
Contributor
There was a problem hiding this comment.
Pull request overview
This PR updates JAX quantization support for MultiHeadAttention by treating the post-softmax attention probabilities as having a known fixed range, avoiding calibration/min-max collection for that tensor.
Changes:
- Added a calibration-status helper (
MinMaxObserver.is_calibrated()) to detect whether observer stats were populated. - Introduced
fixed_rangesupport in both static and dynamic QDQ helper layers, enabling scale computation without observers/per-batch min/max. - Applied
fixed_range=(0.0, 1.0)to the post-softmax attention tensor (a_qdq) in both static and dynamicMultiHeadAttentionquantization paths.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
neural_compressor/jax/quantization/layers_static.py |
Adds fixed_range support to static QDQ, introduces is_calibrated(), and uses fixed [0,1] range for MHA attention probabilities post-softmax. |
neural_compressor/jax/quantization/layers_dynamic.py |
Adds fixed_range support to dynamic QDQ by precomputing scale/zero-point and uses fixed [0,1] range for MHA attention probabilities post-softmax. |
409f71c to
4e6ec44
Compare
Signed-off-by: Andrzej Kotłowski <andrzej.kotlowski@intel.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Type of Change
feature
Description
For tensor quantized inside MultiHeadAttention after softmax const range of values can be assumed. This way calibration or min.max finding for dynamic quantization is not required for this tensor.
Expected Behavior & Potential Risk
the expected behavior that triggered by this PR
How has this PR been tested?
how to reproduce the test (including hardware information)
Dependency Change?
any library dependency introduced or removed