@@ -144,6 +144,16 @@ def get_calibrated_range(self):
144144 """
145145 return ops .array ((self .min_val , self .max_val ))
146146
+    def is_calibrated(self):
+        """Check if the observer has valid calibration data.
+
+        Returns:
+            bool: True if calibrated, False if min (and internally max) are still at their initial values.
+        """
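+        # min_val is assumed to be initialized to +inf (and max_val to -inf),
+        # so a finite min means at least one batch has been observed.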
+        return not (jnp.isinf(self.min_val.value).any())
+

 class StaticQDQLayer(SaveableLayerMixin, keras.layers.Layer):
     """Layer that applies static quantize-dequantize to activations."""
@@ -652,12 +662,31 @@ def convert(self):
         Returns:
             None: Updates QDQ helpers with calibrated values.
         """
-        self.f_qdq.convert()
-        # Calculate the scale for query for dot product attention path
-        # from the fallback path used in calibration
-        self.q_qdq.a_scale.assign(self.f_qdq.a_scale / self._inverse_sqrt_key_dim)
-        if self.q_qdq._is_asymmetric:
-            self.q_qdq.a_zero_point.assign(jnp.array(self.f_qdq.a_zero_point.value))
+        if self.q_qdq.input_observer.is_calibrated():
+            self.q_qdq.convert()
+            if not self.f_qdq.input_observer.is_calibrated():
+                # Calculate the scale for query in the fallback path
+                # from the dot product attention path used in calibration
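+                # (The fallback path is assumed to pre-multiply the query by
+                # 1/sqrt(key_dim), so its scale shrinks by that factor; the
+                # zero point is unchanged by a positive rescale.)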
+                self.f_qdq.a_scale.assign(self.q_qdq.a_scale * self._inverse_sqrt_key_dim)
+                if self.f_qdq._is_asymmetric:
+                    self.f_qdq.a_zero_point.assign(jnp.array(self.q_qdq.a_zero_point.value))
+            else:
+                self.f_qdq.convert()
+        else:
+            self.f_qdq.convert()
+            if not self.q_qdq.input_observer.is_calibrated():
+                # Calculate the scale for query for the dot product attention path
+                # from the fallback path used in calibration
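+                # (Inverse of the relation above: dividing by 1/sqrt(key_dim)
+                # restores the unscaled query range.)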
+                self.q_qdq.a_scale.assign(self.f_qdq.a_scale / self._inverse_sqrt_key_dim)
+                if self.q_qdq._is_asymmetric:
+                    self.q_qdq.a_zero_point.assign(jnp.array(self.f_qdq.a_zero_point.value))
+            else:
+                self.q_qdq.convert()
         self.k_qdq.convert()
         self.a_qdq.convert()
         self.v_qdq.convert()
@@ -721,9 +750,8 @@ def _compute_attention(
             or return_attention_scores
             or (len(query.shape) != 4)
         )
-        # For calibration always use fallback path as it can collect data for both paths
-        use_dot_product_attention = use_dot_product_attention and self._is_quantized
-        use_dot_product_attention = False
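+        # Calibration no longer forces the fallback path: convert() above now
+        # derives parameters for whichever path was not observed.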

         if use_dot_product_attention:
             if attention_mask is not None: