Skip to content

Commit 8a0adab

Browse files
committed
Do not explicitly decide which path will be used in calibration
Signed-off-by: Andrzej Kotłowski <andrzej.kotlowski@intel.com>
1 parent 614e397 commit 8a0adab

1 file changed

Lines changed: 28 additions & 9 deletions

File tree

neural_compressor/jax/quantization/layers_static.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,14 @@ def get_calibrated_range(self):
144144
"""
145145
return ops.array((self.min_val, self.max_val))
146146

147+
def is_calibrated(self):
    """Check if the observer has valid calibration data.

    The observer initializes ``min_val`` to ``+inf`` (and ``max_val`` to
    ``-inf``) before any data is seen, so an infinite ``min_val`` means no
    calibration batch has been observed yet.

    Returns:
        bool: True if calibration data has been collected, False if
            ``min_val`` (and, internally, ``max_val``) are still at their
            initial infinite values.
    """
    # Any remaining inf in min_val means calibration never ran for this
    # observer; checking min_val alone is sufficient since min/max are
    # updated together.
    return not (jnp.isinf(self.min_val.value).any())
154+
147155

148156
class StaticQDQLayer(SaveableLayerMixin, keras.layers.Layer):
149157
"""Layer that applies static quantize-dequantize to activations."""
@@ -652,12 +660,26 @@ def convert(self):
652660
Returns:
653661
None: Updates QDQ helpers with calibrated values.
654662
"""
655-
self.f_qdq.convert()
656-
# Calculate the scale for query for dot product attention path
657-
# from the fallback path used in calibration
658-
self.q_qdq.a_scale.assign(self.f_qdq.a_scale / self._inverse_sqrt_key_dim)
659-
if self.q_qdq._is_asymmetric:
660-
self.q_qdq.a_zero_point.assign(jnp.array(self.f_qdq.a_zero_point.value))
663+
if self.q_qdq.input_observer.is_calibrated():
664+
self.q_qdq.convert()
665+
if not self.f_qdq.input_observer.is_calibrated():
666+
# Calculate the scale for query in the fallback path
667+
# from the dot product attention path used in calibration
668+
self.f_qdq.a_scale.assign(self.q_qdq.a_scale * self._inverse_sqrt_key_dim)
669+
if self.f_qdq._is_asymmetric:
670+
self.f_qdq.a_zero_point.assign(jnp.array(self.q_qdq.a_zero_point.value))
671+
else:
672+
self.f_qdq.convert()
673+
else:
674+
self.f_qdq.convert()
675+
if not self.q_qdq.input_observer.is_calibrated():
676+
# Calculate the scale for query for dot product attention path
677+
# from the fallback path used in calibration
678+
self.q_qdq.a_scale.assign(self.f_qdq.a_scale / self._inverse_sqrt_key_dim)
679+
if self.q_qdq._is_asymmetric:
680+
self.q_qdq.a_zero_point.assign(jnp.array(self.f_qdq.a_zero_point.value))
681+
else:
682+
self.q_qdq.convert()
661683
self.k_qdq.convert()
662684
self.a_qdq.convert()
663685
self.v_qdq.convert()
@@ -721,9 +743,6 @@ def _compute_attention(
721743
or return_attention_scores
722744
or (len(query.shape) != 4)
723745
)
724-
# For calibration always use fallback path as it can collect data for both paths
725-
use_dot_product_attention = use_dot_product_attention and self._is_quantized
726-
use_dot_product_attention = False
727746

728747
if use_dot_product_attention:
729748
if attention_mask is not None:

0 commit comments

Comments
 (0)