@@ -158,12 +158,7 @@ def validate(self, model: torch.fx.GraphModule) -> None:
158158 ),
159159 )
160160
161- def test_composable_quantizer_linear_conv (self ) -> None :
162- # TODO: impelment whe dynamic quantization will be supported by OpenVINOQuantizer
163- pass
164-
165161 def test_embedding_conv_linear_quantization (self ) -> None :
166- # Mark
167162 m_eager = TestHelperModules .EmbeddingConvLinearModule ().eval ()
168163 indices = torch .tensor (
169164 [
@@ -203,57 +198,87 @@ def test_embedding_conv_linear_quantization(self) -> None:
203198 )
204199 indices = torch .unsqueeze (indices , 0 )
205200 example_inputs = (indices ,)
201+ quantizer = OpenVINOQuantizer ()
206202
207- embedding_quantizer = EmbeddingQuantizer ()
208- dynamic_quantizer = XNNPACKQuantizer ()
209- quantization_config_dynamic = get_symmetric_quantization_config (
210- is_per_channel = True , is_dynamic = True
211- )
212- dynamic_quantizer .set_global (quantization_config_dynamic )
213- static_quantizer = XNNPACKQuantizer ()
214- quantization_config = get_symmetric_quantization_config (is_per_channel = True )
215- static_quantizer .set_global (quantization_config )
216- composed_quantizer = ComposableQuantizer (
217- [embedding_quantizer , dynamic_quantizer , static_quantizer ]
218- )
203+ m = self ._quantize (m_eager , quantizer , example_inputs , is_qat = False )
219204
220- act_affine_quant_obs = observer .PlaceholderObserver .with_args (
221- dtype = torch .qint8 ,
222- qscheme = torch .per_tensor_affine ,
223- quant_min = - 128 ,
224- quant_max = 127 ,
225- eps = 2 ** - 12 ,
226- is_dynamic = True ,
227- )
228- dynamic_qconfig = QConfig (
229- activation = act_affine_quant_obs ,
230- weight = per_channel_weight_observer_range_neg_127_to_127 ,
231- )
232- qconfig = default_per_channel_symmetric_qnnpack_qconfig
233- qconfig_mapping = QConfigMapping ().set_global (qconfig )
234- qconfig_mapping .set_object_type (torch .nn .Linear , dynamic_qconfig )
235- qconfig_mapping = qconfig_mapping .set_object_type (
236- torch .nn .Embedding , float_qparams_weight_only_qconfig
237- )
238-
239- node_occurrence = {
240- torch .ops .quantized_decomposed .quantize_per_tensor .default : 4 ,
241- torch .ops .quantized_decomposed .dequantize_per_tensor .default : 4 ,
242- torch .ops .quantized_decomposed .quantize_per_tensor .tensor : 1 ,
243- torch .ops .quantized_decomposed .dequantize_per_tensor .tensor : 1 ,
244- # note: quantize op for weights are const propagated
245- torch .ops .quantized_decomposed .quantize_per_channel .default : 0 ,
246- torch .ops .quantized_decomposed .dequantize_per_channel .default : 3 ,
205+ ref_q = {
206+ # First conv
207+ "quantize_per_tensor_default" : (
208+ None ,
209+ 0.01585131697356701 ,
210+ 127 ,
211+ 0 ,
212+ 255 ,
213+ torch .uint8 ,
214+ ),
215+ "dequantize_per_tensor_default" : (
216+ None ,
217+ 0.01585131697356701 ,
218+ 127 ,
219+ 0 ,
220+ 255 ,
221+ torch .uint8 ,
222+ ),
223+ "dequantize_per_channel_default" : (
224+ None ,
225+ torch .tensor (
226+ [
227+ 0.0015 ,
228+ 0.0015 ,
229+ 0.0015 ,
230+ 0.0016 ,
231+ 0.0015 ,
232+ 0.0016 ,
233+ 0.0014 ,
234+ 0.0014 ,
235+ 0.0015 ,
236+ 0.0015 ,
237+ 0.0016 ,
238+ 0.0015 ,
239+ 0.0015 ,
240+ 0.0016 ,
241+ 0.0016 ,
242+ 0.0015 ,
243+ ]
244+ ),
245+ torch .tensor ([0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ]),
246+ 0 ,
247+ - 128 ,
248+ 127 ,
249+ torch .int8 ,
250+ ),
251+ # First linear
252+ "quantize_per_tensor_default_1" : (
253+ None ,
254+ 0.016017982736229897 ,
255+ 127 ,
256+ 0 ,
257+ 255 ,
258+ torch .uint8 ,
259+ ),
260+ "dequantize_per_tensor_default_1" : (
261+ None ,
262+ 0.016017982736229897 ,
263+ 127 ,
264+ 0 ,
265+ 255 ,
266+ torch .uint8 ,
267+ ),
268+ "dequantize_per_channel_default_1" : (
269+ None ,
270+ torch .tensor (
271+ [0.0019 , 0.0019 , 0.0020 , 0.0018 , 0.0019 , 0.0019 , 0.0018 , 0.0018 ]
272+ ),
273+ torch .tensor ([0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ]),
274+ 0 ,
275+ - 128 ,
276+ 127 ,
277+ torch .int8 ,
278+ ),
279+ # TODO: embedding
247280 }
248- self ._test_quantizer (
249- m_eager ,
250- example_inputs ,
251- composed_quantizer ,
252- node_occurrence ,
253- [],
254- True ,
255- qconfig_mapping ,
256- )
281+ self ._check_quantization_with_ref (m , ref_q )
257282
258283 def test_disallow_eval_train (self ) -> None :
259284 m = TestHelperModules .ConvWithBNRelu (relu = True )
@@ -272,7 +297,7 @@ def test_disallow_eval_train(self) -> None:
272297
273298 # After prepare: still not OK
274299 quantizer = OpenVINOQuantizer ()
275- m = prepare_qat_pt2e (m , quantizer ) # pyre-ignore[6]
300+ m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
276301 with self .assertRaises (NotImplementedError ):
277302 m .eval ()
278303 with self .assertRaises (NotImplementedError ):
@@ -308,11 +333,9 @@ class M(torch.nn.Module):
308333 def __init__ (self ) -> None :
309334 super ().__init__ ()
310335 self .bn = torch .nn .BatchNorm2d (3 )
311- self .dropout = torch .nn .Dropout (0.5 )
312336
313337 def forward (self , x ):
314338 x = self .bn (x )
315- x = self .dropout (x )
316339 return x
317340
318341 m = M ().train ()
@@ -324,8 +347,6 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
324347 bn_op = bn_train_op if train else bn_eval_op
325348 bn_node = self ._get_node (m , bn_op )
326349 self .assertTrue (bn_node is not None )
327- dropout_node = self ._get_node (m , torch .ops .aten .dropout .default )
328- self .assertEqual (dropout_node .args [2 ], train )
329350
330351 # Before wrapping: this is not OK
331352 with self .assertRaises (NotImplementedError ):
@@ -341,8 +362,8 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
341362 _assert_ops_are_correct (m , train = True ) # pyre-ignore[6]
342363
343364 # After prepare but before wrapping: this is not OK
344- quantizer = XNNPACKQuantizer ()
345- m = prepare_qat_pt2e (m , quantizer ) # pyre-ignore[6]
365+ quantizer = OpenVINOQuantizer ()
366+ m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
346367 with self .assertRaises (NotImplementedError ):
347368 m .eval ()
348369 with self .assertRaises (NotImplementedError ):
@@ -677,142 +698,6 @@ def _check_quantization_with_ref(self, model: torch.fx.GraphModule, ref: Dict):
677698
678699 assert len (ref ) == matches
679700
680- def _get_backend_config (self ):
681- def _get_linear_configs ():
682- observation_type = ObservationType .OUTPUT_SHARE_OBSERVER_WITH_INPUT
683- dtype_configs = [
684- DTypeConfig (
685- input_dtype = torch .quint8 ,
686- output_dtype = torch .float ,
687- weight_dtype = torch .qint8 ,
688- bias_dtype = torch .float ,
689- )
690- ]
691- linear_configs : list [BackendPatternConfig ] = []
692- # linear module
693- linear_configs .append (
694- BackendPatternConfig (torch .nn .Linear )
695- .set_observation_type (observation_type ) # noqa: E131
696- .set_dtype_configs (dtype_configs )
697- .set_root_module (torch .nn .Linear )
698- .set_reference_quantized_module (nnqr .Linear )
699- )
700- # functional linear
701- linear_configs .append (
702- BackendPatternConfig (torch .nn .functional .linear )
703- .set_observation_type (observation_type ) # noqa: E131
704- .set_dtype_configs (dtype_configs )
705- ._set_input_type_to_index ({"weight" : 1 , "bias" : 2 })
706- )
707- return linear_configs
708-
709- def _get_conv_configs ():
710- pass
711-
712- return BackendConfig ("OpenVINO" ).set_backend_pattern_configs (
713- _get_linear_configs ()
714- )
715- # .set_backend_pattern_configs(_get_conv_configs())
716-
717- def _test_quantizer (
718- self ,
719- model ,
720- example_inputs ,
721- quantizer ,
722- expected_node_occurrence ,
723- expected_node_list = None ,
724- check_against_fx_quant = False ,
725- fx_qconfig_mapping = None ,
726- export_with_dynamic_shape = False ,
727- is_qat = False ,
728- is_debug_mode = False ,
729- training_ir_node_occurrence = None ,
730- ):
731- # resetting dynamo cache
732- torch ._dynamo .reset ()
733- m_eager = model .eval ()
734-
735- # program capture
736- m = copy .deepcopy (m_eager )
737- dynamic_shapes = tuple (
738- {0 : torch .export .Dim ("dim" )} if i == 0 else None
739- for i in range (len (example_inputs ))
740- )
741- m = export_for_training (
742- m ,
743- example_inputs ,
744- dynamic_shapes = dynamic_shapes if export_with_dynamic_shape else None ,
745- ).module ()
746-
747- if is_qat :
748- m = prepare_qat_pt2e (m , quantizer )
749- else :
750- m = prepare_pt2e (m , quantizer )
751- if is_debug_mode :
752- print ("prepared model:" , m )
753- # Calibrate
754- m (* example_inputs )
755- m = convert_pt2e (m )
756- if is_debug_mode :
757- print ("quantized model" , m )
758-
759- pt2_quant_output = m (* example_inputs )
760- node_occurrence = {
761- ns .call_function (k ): v for k , v in expected_node_occurrence .items ()
762- }
763- if expected_node_list is None :
764- expected_node_list = []
765- node_list = [ns .call_function (n ) for n in expected_node_list ]
766- self .checkGraphModuleNodes (
767- m , expected_node_occurrence = node_occurrence , expected_node_list = node_list
768- )
769- if check_against_fx_quant :
770- qconfig_mapping = fx_qconfig_mapping
771- backend_config = self ._get_backend_config ()
772- m_copy = copy .deepcopy (m_eager )
773- m_fx = prepare_fx (
774- m_copy , qconfig_mapping , example_inputs , backend_config = backend_config
775- )
776- m_fx (* example_inputs )
777- m_fx = _convert_to_reference_decomposed_fx (
778- m_fx , backend_config = backend_config
779- )
780- m_fx = export_for_training (
781- m_fx ,
782- example_inputs ,
783- dynamic_shapes = dynamic_shapes if export_with_dynamic_shape else None ,
784- ).module ()
785- node_occurrence = {}
786- for k , v in PT2EQuantizationTestCase ._MAP_TO_FX_TRACED_OPS .items ():
787- if k in expected_node_occurrence :
788- node_occurrence [ns .call_function (v )] = expected_node_occurrence [k ]
789- if training_ir_node_occurrence is not None :
790- node_occurrence = {
791- ns .call_function (k ): v
792- for k , v in training_ir_node_occurrence .items ()
793- }
794- self .checkGraphModuleNodes (m_fx , expected_node_occurrence = node_occurrence )
795- fx_quant_output = m_fx (* example_inputs )
796- self .assertEqual (fx_quant_output , pt2_quant_output )
797- return m
798- # activation_observer = observer.HistogramObserver
799- default_qconfig = QConfig (
800- activation = activation_observer , weight = weight_observer
801- )
802- qconfig_mapping = QConfigMapping ()
803- qconfig_mapping .set_global (QConfig (activation = None , weight = None ))
804- qconfig_mapping .set_object_type (torch .nn .Linear , default_qconfig )
805- self ._quantize ()
806- self ._test_quantizer (
807- m ,
808- example_inputs ,
809- quantizer ,
810- node_occurrence ,
811- check_against_fx_quant = True ,
812- fx_qconfig_mapping = qconfig_mapping ,
813- )
814- # self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, )
815-
816701 def test_save_load (self ) -> None :
817702 """Test save/load a quantized model"""
818703 m = self ._get_linear ()
0 commit comments