
[JAX] Use const scale in MHA after softmax #2466

Merged
anko-intel merged 4 commits into master from dev/anko/know_scale on May 12, 2026

Conversation

@anko-intel (Contributor)

Type of Change

feature

Description

For a tensor quantized inside MultiHeadAttention after softmax, a constant range of values can be assumed, since softmax outputs are probabilities in [0, 1]. This way, calibration (or min/max finding for dynamic quantization) is not required for this tensor.
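As a sketch of the idea (the helper name and signature below are illustrative, not the PR's actual code): because softmax outputs lie in [0, 1] by construction, the affine quantization scale and zero-point can be computed once from that constant range, with no observer or calibration pass.

```python
# Illustrative sketch, not Neural Compressor's actual API: derive an int8
# affine quantization scale/zero-point from a known constant value range.
def scale_from_fixed_range(fixed_range, qmin=-128, qmax=127):
    """Compute scale and zero-point for a tensor with a known fixed range."""
    rmin, rmax = fixed_range
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = qmin - round(rmin / scale)
    return scale, zero_point

# Softmax outputs are probabilities, so the range is [0.0, 1.0] by construction;
# no per-batch min/max pass is needed.
scale, zp = scale_from_fixed_range((0.0, 1.0))
```

With the range fixed at [0, 1], the scale is simply 1/255 for signed int8, and it never changes between batches.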

Expected Behavior & Potential Risk

the expected behavior triggered by this PR

How has this PR been tested?

how to reproduce the test (including hardware information)

Dependency Change?

any library dependency introduced or removed

For a tensor quantized after softmax, a constant range of values can be
assumed.

Signed-off-by: Andrzej Kotłowski <andrzej.kotlowski@intel.com>
Contributor

Copilot AI left a comment


Pull request overview

This PR updates JAX quantization support for MultiHeadAttention by treating the post-softmax attention probabilities as having a known fixed range, avoiding calibration/min-max collection for that tensor.

Changes:

  • Added a calibration-status helper (MinMaxObserver.is_calibrated()) to detect whether observer stats were populated.
  • Introduced fixed_range support in both static and dynamic QDQ helper layers, enabling scale computation without observers/per-batch min/max.
  • Applied fixed_range=(0.0, 1.0) to the post-softmax attention tensor (a_qdq) in both static and dynamic MultiHeadAttention quantization paths.
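The bullet points above can be sketched as follows. This is a hypothetical illustration of the `fixed_range` idea in a QDQ (quantize-dequantize) helper; the function name and parameters are assumptions for illustration, not the PR's exact code.

```python
# Hypothetical sketch of fixed_range support in a QDQ helper: when a constant
# range is known, skip min/max collection entirely; otherwise fall back to the
# dynamic per-batch min/max path that fixed_range is meant to avoid.
import jax.numpy as jnp

def fake_quantize(x, fixed_range=None, qmin=-128, qmax=127):
    if fixed_range is not None:
        # Known constant range: scale/zero-point can be precomputed.
        rmin, rmax = fixed_range
    else:
        # Dynamic path: per-batch min/max (extra work fixed_range removes).
        rmin, rmax = jnp.min(x), jnp.max(x)
    scale = (rmax - rmin) / (qmax - qmin)
    zp = qmin - jnp.round(rmin / scale)
    q = jnp.clip(jnp.round(x / scale) + zp, qmin, qmax)
    return (q - zp) * scale  # dequantize back to float

# Post-softmax attention probabilities always lie in [0, 1].
probs = jnp.array([0.1, 0.7, 0.2])
out = fake_quantize(probs, fixed_range=(0.0, 1.0))
```

For the static path the same fixed range would bypass the observer, and for the dynamic path it would replace the per-batch reduction; in both cases the quantization error stays bounded by half a quantization step (about 1/510 for int8 over [0, 1]).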

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

  • neural_compressor/jax/quantization/layers_static.py: Adds fixed_range support to the static QDQ layer, introduces is_calibrated(), and uses a fixed [0, 1] range for MHA post-softmax attention probabilities.
  • neural_compressor/jax/quantization/layers_dynamic.py: Adds fixed_range support to the dynamic QDQ layer by precomputing the scale/zero-point, and uses a fixed [0, 1] range for MHA post-softmax attention probabilities.
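The `is_calibrated()` helper mentioned for the static path can be sketched as below. This is an assumed minimal shape of a min/max observer, not the actual `MinMaxObserver` implementation in `neural_compressor/jax`.

```python
# Minimal sketch (assumed, not the real implementation) of a min/max observer
# with a calibration-status check: stats exist only after at least one
# calibration batch has been observed.
class MinMaxObserver:
    def __init__(self):
        self.min_val = None
        self.max_val = None

    def observe(self, x_min, x_max):
        # Accumulate the running min/max across calibration batches.
        self.min_val = x_min if self.min_val is None else min(self.min_val, x_min)
        self.max_val = x_max if self.max_val is None else max(self.max_val, x_max)

    def is_calibrated(self):
        # True once observer stats were populated by calibration data.
        return self.min_val is not None and self.max_val is not None

obs = MinMaxObserver()
assert not obs.is_calibrated()   # no calibration data seen yet
obs.observe(0.0, 1.0)
assert obs.is_calibrated()       # stats populated after one batch
```

A tensor with `fixed_range` set would never need this check, which is exactly the point of the PR for the post-softmax attention tensor.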

@anko-intel requested a review from bkowalskiINTEL on May 12, 2026 at 14:40
Contributor

@bkowalskiINTEL left a comment


LGTM

@anko-intel merged commit ec84358 into master on May 12, 2026
14 checks passed
@anko-intel deleted the dev/anko/know_scale branch on May 12, 2026 at 14:58