Skip to content

Commit c5b0875

Browse files
ratgrtensorflower-gardener
authored andcommitted
Fix an issue where multiple instances of DefaultNBitQuantizeRegistry could interfere with each other's configuration.
PiperOrigin-RevId: 931228650
1 parent 97811c3 commit c5b0875

4 files changed

Lines changed: 36 additions & 19 deletions

File tree

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/BUILD

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ py_library(
6464
],
6565
# strict_deps = True,
6666
deps = [
67+
":default_n_bit_quantize_configs",
68+
":default_n_bit_quantizers",
6769
# tensorflow dep1,
6870
"//tensorflow_model_optimization/python/core/keras:compat",
6971
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_config",
7072
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_registry",
7173
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
72-
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantize_configs",
73-
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantizers",
7474
],
7575
)
7676

@@ -88,6 +88,7 @@ py_test(
8888
# tensorflow dep1,
8989
"//tensorflow_model_optimization/python/core/keras:compat",
9090
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
91+
"//tensorflow_model_optimization/python/core/quantization/keras:utils",
9192
],
9293
)
9394

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from typing import Any, Dict
2222

23-
import tensorflow as tf
2423

2524
from tensorflow_model_optimization.python.core.keras.compat import keras
2625
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
@@ -185,9 +184,15 @@ def __init__(self, disable_per_axis=False,
185184
self._num_bits_activation = num_bits_activation
186185
self._layer_quantize_map = {}
187186
for quantize_info in self._LAYER_QUANTIZE_INFO:
188-
quantize_info.num_bits_weight = num_bits_weight
189-
quantize_info.num_bits_activation = num_bits_activation
190-
self._layer_quantize_map[quantize_info.layer_type] = quantize_info
187+
new_quantize_info = _QuantizeInfo(
188+
layer_type=quantize_info.layer_type,
189+
weight_attrs=quantize_info.weight_attrs,
190+
activation_attrs=quantize_info.activation_attrs,
191+
quantize_output=quantize_info.quantize_output,
192+
num_bits_weight=num_bits_weight,
193+
num_bits_activation=num_bits_activation,
194+
)
195+
self._layer_quantize_map[new_quantize_info.layer_type] = new_quantize_info
191196

192197
# Hack for `Activation` layer. That is the only layer with a separate
193198
# QuantizeConfig.
@@ -548,9 +553,13 @@ class DefaultNBitConvQuantizeConfig(DefaultNBitQuantizeConfig):
548553
def __init__(self, weight_attrs, activation_attrs,
549554
quantize_output, num_bits_weight: int = 8,
550555
num_bits_activation: int = 8):
551-
super(DefaultNBitConvQuantizeConfig,
552-
self).__init__(weight_attrs, activation_attrs,
553-
quantize_output, num_bits_weight, num_bits_activation)
556+
super().__init__(
557+
weight_attrs,
558+
activation_attrs,
559+
quantize_output,
560+
num_bits_weight,
561+
num_bits_activation,
562+
)
554563
self._num_bits_weight = num_bits_weight
555564
self._num_bits_activation = num_bits_activation
556565
self.weight_quantizer = n_bit_quantizers.DefaultNBitConvWeightsQuantizer(
@@ -563,17 +572,22 @@ class DefaultNBitConvTransposeQuantizeConfig(
563572

564573
def __init__(self, weight_attrs, activation_attrs, quantize_output,
565574
num_bits_weight: int = 8, num_bits_activation: int = 8):
566-
super(DefaultNBitConvTransposeQuantizeConfig,
567-
self).__init__(weight_attrs, activation_attrs, quantize_output,
568-
num_bits_weight, num_bits_activation)
575+
super().__init__(
576+
weight_attrs,
577+
activation_attrs,
578+
quantize_output,
579+
num_bits_weight,
580+
num_bits_activation,
581+
)
569582
self._num_bits_weight = num_bits_weight
570583
self._num_bits_activation = num_bits_activation
571584

572585
self.weight_quantizer = n_bit_quantizers.DefaultNBitConvTransposeWeightsQuantizer(
573586
num_bits_weight, num_bits_activation)
574587

575588

576-
def _types_dict():
589+
def types_dict():
590+
"""Returns a dictionary mapping type names to classes for deserialization."""
577591
return {
578592
'DefaultNBitQuantizeConfig':
579593
DefaultNBitQuantizeConfig,

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class QuantizeRegistryTest(
7878
tf.test.TestCase, parameterized.TestCase, _TestHelper):
7979

8080
def setUp(self):
81-
super(QuantizeRegistryTest, self).setUp()
81+
super().setUp()
8282
self.quantize_registry = n_bit_registry.DefaultNBitQuantizeRegistry(
8383
num_bits_weight=4, num_bits_activation=4)
8484

@@ -114,7 +114,7 @@ class MinimalRNNCell(l.Layer):
114114
def __init__(self, units, **kwargs):
115115
self.units = units
116116
self.state_size = units
117-
super(MinimalRNNCell, self).__init__(**kwargs)
117+
super().__init__(**kwargs)
118118

119119
self.assertFalse(
120120
self.quantize_registry.supports(l.RNN(cell=MinimalRNNCell(10))))
@@ -374,15 +374,16 @@ def testSerialization(self):
374374
quantize_config_from_config = deserialize_keras_object(
375375
serialized_quantize_config,
376376
module_objects=globals(),
377-
custom_objects=n_bit_registry._types_dict())
377+
custom_objects=n_bit_registry.types_dict(),
378+
)
378379

379380
self.assertEqual(quantize_config, quantize_config_from_config)
380381

381382

382383
class DefaultNBitQuantizeConfigRNNTest(tf.test.TestCase, _TestHelper):
383384

384385
def setUp(self):
385-
super(DefaultNBitQuantizeConfigRNNTest, self).setUp()
386+
super().setUp()
386387

387388
self.cell1 = l.LSTMCell(3)
388389
self.cell2 = l.GRUCell(2)
@@ -493,7 +494,8 @@ def testSerialization(self):
493494
quantize_config_from_config = deserialize_keras_object(
494495
serialized_quantize_config,
495496
module_objects=globals(),
496-
custom_objects=n_bit_registry._types_dict())
497+
custom_objects=n_bit_registry.types_dict(),
498+
)
497499

498500
self.assertEqual(self.quantize_config, quantize_config_from_config)
499501

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def quantize_scope(*args):
7474
'FixedQuantizeConfig': quantize_config_mod.FixedQuantizeConfig,
7575
}
7676
quantization_objects.update(default_8bit_quantize_registry._types_dict()) # pylint: disable=protected-access
77-
quantization_objects.update(default_n_bit_quantize_registry._types_dict()) # pylint: disable=protected-access
77+
quantization_objects.update(default_n_bit_quantize_registry.types_dict())
7878
quantization_objects.update(quantizers._types_dict()) # pylint: disable=protected-access
7979

8080
return keras.utils.custom_object_scope(*(args + (quantization_objects,)))

0 commit comments

Comments
 (0)