70 changes: 56 additions & 14 deletions neural_compressor/jax/quantization/layers_dynamic.py
@@ -68,14 +68,16 @@ def decorator(cls):
class DynamicQDQLayer(SaveableLayerMixin, keras.layers.Layer):
"""Layer that applies dynamic quantize-dequantize to activations."""

def __init__(self, name, activation_dtype, dtype="float32", asymmetric=False):
def __init__(self, name, activation_dtype, dtype="float32", asymmetric=False, fixed_range=None):
"""Initialize the dynamic QDQ helper layer.

Args:
name (str): Layer name.
activation_dtype (jnp.dtype): Activation dtype used for quantization.
dtype (str): dtype for the layer - see keras.layers.Layer API for details.
asymmetric (bool): Whether to use asymmetric quantization.
fixed_range (Optional[Tuple[float, float]]): If provided, use this (min, max) range
instead of computing min/max dynamically per batch.

Returns:
None: Initializes the layer instance.
@@ -84,46 +86,56 @@ def __init__(self, name, activation_dtype, dtype="float32", asymmetric=False):
self.activation_dtype = activation_dtype
self._is_asymmetric = asymmetric
self.supports_masking = True
self.fixed_range = fixed_range

def add_variables(self):
"""Create quantization helper functions for activations.

        When fixed_range is set, the scale (and zero point) is pre-computed so that
        call() avoids per-batch min/max computation.

Returns:
None: Initializes quantization functions.
"""
self._tracker.unlock()
self.aquantfun = get_quantize_fun(dtype=self.activation_dtype, asymmetric=self._is_asymmetric)
self.adequantfun = get_dequantize_fun(dtype=self.compute_dtype, asymmetric=self._is_asymmetric)
if self.fixed_range is not None:
fixed_min_max = ops.array(self.fixed_range)
a_scale, a_zero_point = get_q_params(
fixed_min_max, self.activation_dtype, self.compute_dtype, asymmetric=self._is_asymmetric
)
self._fixed_a_scale = a_scale
if self._is_asymmetric:
self._fixed_a_zero_point = a_zero_point
self.call = self.call_fixed_asymmetric if self._is_asymmetric else self.call_fixed_symmetric
self._tracker.lock()

def call_symmetric(self, inputs, batch_min_max, mask=None):
"""Apply symmetric quantization to inputs.
def call_symmetric(self, inputs, a_scale):
"""Apply symmetric quantize-dequantize with given scale.

Args:
inputs (jnp.ndarray): Input tensor.
batch_min_max (jnp.ndarray): Min/max tensor for the batch.
mask (Optional[jnp.ndarray]): Optional mask tensor.
a_scale (jnp.ndarray): Quantization scale.

Returns:
jnp.ndarray: Quantized-dequantized tensor.
"""
a_scale, _ = get_q_params(batch_min_max, self.activation_dtype, self.compute_dtype, asymmetric=False)
x = self.aquantfun(inputs, a_scale)
x = self.adequantfun(x, a_scale)
return x

def call_asymmetric(self, inputs, batch_min_max, mask=None):
"""Apply asymmetric quantization to inputs.
def call_asymmetric(self, inputs, a_scale, a_zero_point):
"""Apply asymmetric quantize-dequantize with given scale and zero point.

Args:
inputs (jnp.ndarray): Input tensor.
batch_min_max (jnp.ndarray): Min/max tensor for the batch.
mask (Optional[jnp.ndarray]): Optional mask tensor.
a_scale (jnp.ndarray): Quantization scale.
a_zero_point (jnp.ndarray): Quantization zero point.

Returns:
jnp.ndarray: Quantized-dequantized tensor.
"""
a_scale, a_zero_point = get_q_params(batch_min_max, self.activation_dtype, self.compute_dtype, asymmetric=True)
x = self.aquantfun(inputs, a_scale, a_zero_point)
x = self.adequantfun(x, a_scale, a_zero_point)
return x
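
As an aside (not part of this diff): with fixed_range the scale can be derived once up front instead of per batch. The sketch below assumes int8 symmetric quantization with a simple max-abs rule; the actual rule implemented by get_q_params may differ.

import jax.numpy as jnp

def make_fixed_symmetric_qdq(fixed_range=(0.0, 1.0)):
    # Scale is computed once from the known range, never from batch statistics (assumed max-abs rule).
    lo, hi = fixed_range
    scale = max(abs(lo), abs(hi)) / 127.0
    def qdq(x):
        q = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)  # quantize
        return q.astype(jnp.float32) * scale                            # dequantize
    return qdq

This is the effect call_fixed_symmetric/call_fixed_asymmetric achieve: no per-batch min/max scan is needed.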
@@ -158,8 +170,36 @@ def call(self, inputs, mask=None):
batch_min_max = keras.ops.array((batch_min, batch_max))

if self._is_asymmetric:
return self.call_asymmetric(inputs, batch_min_max, mask)
return self.call_symmetric(inputs, batch_min_max, mask)
a_scale, a_zero_point = get_q_params(
batch_min_max, self.activation_dtype, self.compute_dtype, asymmetric=True
)
return self.call_asymmetric(inputs, a_scale, a_zero_point)
a_scale, _ = get_q_params(batch_min_max, self.activation_dtype, self.compute_dtype, asymmetric=False)
return self.call_symmetric(inputs, a_scale)

def call_fixed_symmetric(self, inputs, mask=None):
"""Apply symmetric quantization using a pre-computed fixed scale.

Args:
inputs (jnp.ndarray): Input tensor.
mask (Optional[jnp.ndarray]): Optional mask tensor.

Returns:
jnp.ndarray: Quantized-dequantized tensor.
"""
return self.call_symmetric(inputs, self._fixed_a_scale)

def call_fixed_asymmetric(self, inputs, mask=None):
"""Apply asymmetric quantization using a pre-computed fixed scale.

Args:
inputs (jnp.ndarray): Input tensor.
mask (Optional[jnp.ndarray]): Optional mask tensor.

Returns:
jnp.ndarray: Quantized-dequantized tensor.
"""
return self.call_asymmetric(inputs, self._fixed_a_scale, self._fixed_a_zero_point)
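
For contrast, a minimal sketch of the dynamic path (not part of this diff): the range is scanned from each batch before the scale is derived, again assuming an int8 max-abs rule that get_q_params may or may not use.

import jax.numpy as jnp

def dynamic_symmetric_qdq(x):
    batch_min, batch_max = jnp.min(x), jnp.max(x)                        # per-batch range
    scale = jnp.maximum(jnp.abs(batch_min), jnp.abs(batch_max)) / 127.0  # assumed max-abs rule
    q = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
    return q.astype(jnp.float32) * scale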


class QDynamicDenseMixin(SaveableLayerMixin):
@@ -322,7 +362,9 @@ def prepare(cls, orig, weight_dtype, activation_dtype, const_scale=False, const_
orig._is_int8 = jnp.issubdtype(activation_dtype, jnp.integer)
orig.q_qdq = DynamicQDQLayer("q_qdq", activation_dtype, orig.dtype_policy, False)
orig.k_qdq = DynamicQDQLayer("k_qdq", activation_dtype, orig.dtype_policy, orig._is_int8)
orig.a_qdq = DynamicQDQLayer("a_qdq", activation_dtype, orig.dtype_policy, orig._is_int8)
orig.a_qdq = DynamicQDQLayer(
"a_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, fixed_range=(0.0, 1.0)
)
orig.v_qdq = DynamicQDQLayer("v_qdq", activation_dtype, orig.dtype_policy, False)
orig._tracker.lock()
return orig
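
A note on the fixed_range=(0.0, 1.0) choice, assuming a_qdq quantizes the post-softmax attention probabilities (which the surrounding q/k/v layer names suggest): softmax outputs are bounded in [0, 1] regardless of the inputs, so the range is known without observing any data. A quick check:

import jax
import jax.numpy as jnp

scores = jnp.array([[2.0, -1.0, 0.5]])
probs = jax.nn.softmax(scores, axis=-1)
assert bool(jnp.all((probs >= 0.0) & (probs <= 1.0)))  # holds for any scores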
68 changes: 57 additions & 11 deletions neural_compressor/jax/quantization/layers_static.py
@@ -144,11 +144,19 @@ def get_calibrated_range(self):
"""
return ops.array((self.min_val, self.max_val))

def is_calibrated(self):
"""Check if the observer has valid calibration data.

Returns:
            bool: True if calibrated, False if min (and max) are still at their initial values.
"""
return not (jnp.isinf(self.min_val.value).any())
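
Illustrative sketch (not part of this diff), assuming the observer initializes its running min/max to +/-inf: an infinite min_val then simply means no batch has been observed yet.

import jax.numpy as jnp

min_val = jnp.array(jnp.inf)                                     # assumed initial state
assert jnp.isinf(min_val).any()                                  # not calibrated yet
min_val = jnp.minimum(min_val, jnp.min(jnp.array([0.1, 0.9])))   # one calibration batch
assert not jnp.isinf(min_val).any()                              # now calibrated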


class StaticQDQLayer(SaveableLayerMixin, keras.layers.Layer):
"""Layer that applies static quantize-dequantize to activations."""

def __init__(self, name, activation_dtype, dtype="float32", asymmetric=False, const_scale=False):
def __init__(self, name, activation_dtype, dtype="float32", asymmetric=False, const_scale=False, fixed_range=None):
"""Initialize the static QDQ helper layer.

Args:
@@ -157,6 +165,8 @@ def __init__(self, name, activation_dtype, dtype="float32", asymmetric=False, co
dtype (str | keras.DTypePolicy): dtype for the layer - see keras.layers.Layer API for details.
asymmetric (bool): Whether to use asymmetric quantization.
const_scale (bool): Whether to use constant scales.
fixed_range (Optional[Tuple[float, float]]): If provided, use this (min, max) range
instead of collecting calibration data via observers.

Returns:
None: Initializes the layer instance.
@@ -167,6 +177,9 @@ def __init__(self, name, activation_dtype, dtype="float32", asymmetric=False, co
self.supports_masking = True
self._is_quantized = False
self.const_scale = const_scale
self.fixed_range = fixed_range
if fixed_range is not None:
self.call = self.call_passthrough
if const_scale:
self._const_variables = ["a_scale"]
if asymmetric:
@@ -177,9 +190,13 @@ def add_observers(self):
def add_observers(self):
"""Attach observer layers for calibration.

Skipped when fixed_range is set, as no calibration data is needed.

Returns:
None: Adds observer layers.
"""
if self.fixed_range is not None:
return
self._tracker.unlock()
self.input_observer = MinMaxObserver(dtype=self.dtype_policy)
self._tracker.lock()
@@ -215,11 +232,16 @@ def add_variables(self):
def convert(self):
"""Compute activation scale and finalize static quantization.

Uses fixed_range if set, otherwise reads from the calibration observer.

Returns:
None: Updates activation scale variables.
"""
self._tracker.unlock()
arange = self.input_observer.get_calibrated_range()
if self.fixed_range is not None:
arange = ops.array(self.fixed_range)
else:
arange = self.input_observer.get_calibrated_range()
a_scale, a_zero_point = get_q_params(
arange, self.activation_dtype, self.compute_dtype, asymmetric=self._is_asymmetric
)
@@ -272,6 +294,18 @@ def call(self, inputs, mask=None):
x = self.input_observer(inputs, mask=mask)
return x

def call_passthrough(self, inputs, mask=None):
"""Pass inputs through without observation.

Args:
inputs (jnp.ndarray): Input tensor.
mask (Optional[jnp.ndarray]): Optional mask tensor.

Returns:
jnp.ndarray: Unmodified inputs.
"""
return inputs
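
A minimal sketch (not part of this diff) of how the static fixed_range flow fits together: the calibration call is a passthrough, and convert() derives the scale from the known range. The max-abs/127 rule here is an assumption for illustration; the real convert() goes through get_q_params.

class FixedRangeStaticQDQSketch:
    def __init__(self, fixed_range=(0.0, 1.0)):
        self.fixed_range = fixed_range
        self.a_scale = None
    def __call__(self, x):
        return x                                      # calibration pass: nothing to observe
    def convert(self):
        lo, hi = self.fixed_range
        self.a_scale = max(abs(lo), abs(hi)) / 127.0  # scale from the known range, not an observer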

def call_symmetric(self, inputs, mask=None):
"""Apply symmetric quantize-dequantize to inputs.

@@ -588,7 +622,9 @@ def prepare(cls, orig, weight_dtype, activation_dtype, const_scale=False, const_
# f_qdq is used for quantize/dequantize of query tensor on fallback path (without using dot_product_attention)
orig.f_qdq = StaticQDQLayer("f_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig.k_qdq = StaticQDQLayer("k_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig.a_qdq = StaticQDQLayer("a_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig.a_qdq = StaticQDQLayer(
"a_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale, fixed_range=(0.0, 1.0)
)
orig.v_qdq = StaticQDQLayer("v_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig._is_quantized = False
orig._tracker.lock()
@@ -624,12 +660,24 @@ def convert(self):
Returns:
None: Updates QDQ helpers with calibrated values.
"""
self.f_qdq.convert()
# Calculate the scale for query for dot product attention path
# from the fallback path used in calibration
self.q_qdq.a_scale.assign(self.f_qdq.a_scale / self._inverse_sqrt_key_dim)
if self.q_qdq._is_asymmetric:
self.q_qdq.a_zero_point.assign(jnp.array(self.f_qdq.a_zero_point.value))
if self.q_qdq.input_observer.is_calibrated():
self.q_qdq.convert()
if not self.f_qdq.input_observer.is_calibrated():
# Calculate the scale for query in the fallback path
# from the dot product attention path used in calibration
self.f_qdq.a_scale.assign(self.q_qdq.a_scale * self._inverse_sqrt_key_dim)
if self.f_qdq._is_asymmetric:
self.f_qdq.a_zero_point.assign(jnp.array(self.q_qdq.a_zero_point.value))
else:
self.f_qdq.convert()
else:
self.f_qdq.convert()
# Calculate the scale for query for dot product attention path
# from the fallback path used in calibration
self.q_qdq.a_scale.assign(self.f_qdq.a_scale / self._inverse_sqrt_key_dim)
if self.q_qdq._is_asymmetric:
self.q_qdq.a_zero_point.assign(jnp.array(self.f_qdq.a_zero_point.value))

self.k_qdq.convert()
self.a_qdq.convert()
self.v_qdq.convert()
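
A worked example of the scale relationship used in convert() above, assuming _inverse_sqrt_key_dim equals 1/sqrt(key_dim) as its name suggests: the fallback path sees the query already multiplied by that factor, so its scale shrinks by the same factor, and either scale can be recovered from the other depending on which path was calibrated.

key_dim = 64
inverse_sqrt_key_dim = 1.0 / (key_dim ** 0.5)   # 0.125
q_scale = 0.02                                   # hypothetical calibrated scale on the dot-product path
f_scale = q_scale * inverse_sqrt_key_dim         # 0.0025 for the fallback path
assert abs(f_scale / inverse_sqrt_key_dim - q_scale) < 1e-12  # the inverse mapping recovers q_scale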
Expand Down Expand Up @@ -693,8 +741,6 @@ def _compute_attention(
or return_attention_scores
or (len(query.shape) != 4)
)
# For calibration always use fallback path as it can collect data for both paths
use_dot_product_attention = use_dot_product_attention and self._is_quantized

if use_dot_product_attention:
if attention_mask is not None: