@@ -75,17 +75,17 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
7575
7676
7777if is_torchao_available ():
78- from torchao .dtypes import AffineQuantizedTensor
7978 from torchao .quantization import (
8079 Float8WeightOnlyConfig ,
80+ Int4Tensor ,
8181 Int4WeightOnlyConfig ,
8282 Int8DynamicActivationInt8WeightConfig ,
8383 Int8DynamicActivationIntxWeightConfig ,
84+ Int8Tensor ,
8485 Int8WeightOnlyConfig ,
8586 IntxWeightOnlyConfig ,
8687 )
87- from torchao .quantization .linear_activation_quantized_tensor import LinearActivationQuantizedTensor
88- from torchao .utils import get_model_size_in_bytes
88+ from torchao .utils import TorchAOBaseTensor , get_model_size_in_bytes
8989
9090
9191@require_torch
@@ -260,9 +260,7 @@ def test_int4wo_quant_bfloat16_conversion(self):
260260 )
261261
262262 weight = quantized_model .transformer_blocks [0 ].ff .net [2 ].weight
263- self .assertTrue (isinstance (weight , AffineQuantizedTensor ))
264- self .assertEqual (weight .quant_min , 0 )
265- self .assertEqual (weight .quant_max , 15 )
263+ self .assertTrue (isinstance (weight , Int4Tensor ))
266264
267265 def test_device_map (self ):
268266 """
@@ -322,7 +320,7 @@ def test_device_map(self):
322320 if "transformer_blocks.0" in device_map :
323321 self .assertTrue (isinstance (weight , nn .Parameter ))
324322 else :
325- self .assertTrue (isinstance (weight , AffineQuantizedTensor ))
323+ self .assertTrue (isinstance (weight , Int4Tensor ))
326324
327325 output = quantized_model (** inputs )[0 ]
328326 output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
@@ -343,7 +341,7 @@ def test_device_map(self):
343341 if "transformer_blocks.0" in device_map :
344342 self .assertTrue (isinstance (weight , nn .Parameter ))
345343 else :
346- self .assertTrue (isinstance (weight , AffineQuantizedTensor ))
344+ self .assertTrue (isinstance (weight , Int4Tensor ))
347345
348346 output = quantized_model (** inputs )[0 ]
349347 output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
@@ -360,11 +358,11 @@ def test_modules_to_not_convert(self):
360358
361359 unquantized_layer = quantized_model_with_not_convert .transformer_blocks [0 ].ff .net [2 ]
362360 self .assertTrue (isinstance (unquantized_layer , torch .nn .Linear ))
363- self .assertFalse (isinstance (unquantized_layer .weight , AffineQuantizedTensor ))
361+ self .assertFalse (isinstance (unquantized_layer .weight , Int8Tensor ))
364362 self .assertEqual (unquantized_layer .weight .dtype , torch .bfloat16 )
365363
366364 quantized_layer = quantized_model_with_not_convert .proj_out
367- self .assertTrue (isinstance (quantized_layer .weight , AffineQuantizedTensor ))
365+ self .assertTrue (isinstance (quantized_layer .weight , Int8Tensor ))
368366
369367 quantization_config = TorchAoConfig (Int8WeightOnlyConfig ())
370368 quantized_model = FluxTransformer2DModel .from_pretrained (
@@ -448,18 +446,18 @@ def test_memory_footprint(self):
448446
449447 # Will not quantize all the layers by default due to the model weight shapes not being divisible by group_size=64
450448 for block in transformer_int4wo .transformer_blocks :
451- self .assertTrue (isinstance (block .ff .net [2 ].weight , AffineQuantizedTensor ))
452- self .assertTrue (isinstance (block .ff_context .net [2 ].weight , AffineQuantizedTensor ))
449+ self .assertTrue (isinstance (block .ff .net [2 ].weight , Int4Tensor ))
450+ self .assertTrue (isinstance (block .ff_context .net [2 ].weight , Int4Tensor ))
453451
454452 # Will quantize all the linear layers except x_embedder
455453 for name , module in transformer_int4wo_gs32 .named_modules ():
456454 if isinstance (module , nn .Linear ) and name not in ["x_embedder" ]:
457- self .assertTrue (isinstance (module .weight , AffineQuantizedTensor ))
455+ self .assertTrue (isinstance (module .weight , Int4Tensor ))
458456
459457 # Will quantize all the linear layers
460458 for module in transformer_int8wo .modules ():
461459 if isinstance (module , nn .Linear ):
462- self .assertTrue (isinstance (module .weight , AffineQuantizedTensor ))
460+ self .assertTrue (isinstance (module .weight , Int8Tensor ))
463461
464462 total_int4wo = get_model_size_in_bytes (transformer_int4wo )
465463 total_int4wo_gs32 = get_model_size_in_bytes (transformer_int4wo_gs32 )
@@ -588,7 +586,7 @@ def _test_original_model_expected_slice(self, quant_type, expected_slice):
588586 output = quantized_model (** inputs )[0 ]
589587 output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
590588 weight = quantized_model .transformer_blocks [0 ].ff .net [2 ].weight
591- self .assertTrue (isinstance (weight , ( AffineQuantizedTensor , LinearActivationQuantizedTensor ) ))
589+ self .assertTrue (isinstance (weight , TorchAOBaseTensor ))
592590 self .assertTrue (numpy_cosine_similarity_distance (output_slice , expected_slice ) < 1e-3 )
593591
594592 def _check_serialization_expected_slice (self , quant_type , expected_slice , device ):
@@ -604,11 +602,7 @@ def _check_serialization_expected_slice(self, quant_type, expected_slice, device
604602 output = loaded_quantized_model (** inputs )[0 ]
605603
606604 output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
607- self .assertTrue (
608- isinstance (
609- loaded_quantized_model .proj_out .weight , (AffineQuantizedTensor , LinearActivationQuantizedTensor )
610- )
611- )
605+ self .assertTrue (isinstance (loaded_quantized_model .proj_out .weight , TorchAOBaseTensor ))
612606 self .assertTrue (numpy_cosine_similarity_distance (output_slice , expected_slice ) < 1e-3 )
613607
614608 def test_int_a8w8_accelerator (self ):
@@ -756,7 +750,7 @@ def _test_quant_type(self, quantization_config, expected_slice):
756750 pipe .enable_model_cpu_offload ()
757751
758752 weight = pipe .transformer .transformer_blocks [0 ].ff .net [2 ].weight
759- self .assertTrue (isinstance (weight , ( AffineQuantizedTensor , LinearActivationQuantizedTensor ) ))
753+ self .assertTrue (isinstance (weight , TorchAOBaseTensor ))
760754
761755 inputs = self .get_dummy_inputs (torch_device )
762756 output = pipe (** inputs )[0 ].flatten ()
@@ -790,7 +784,7 @@ def test_serialization_int8wo(self):
790784 pipe .enable_model_cpu_offload ()
791785
792786 weight = pipe .transformer .x_embedder .weight
793- self .assertTrue (isinstance (weight , AffineQuantizedTensor ))
787+ self .assertTrue (isinstance (weight , Int8Tensor ))
794788
795789 inputs = self .get_dummy_inputs (torch_device )
796790 output = pipe (** inputs )[0 ].flatten ()[:128 ]
@@ -809,7 +803,7 @@ def test_serialization_int8wo(self):
809803 pipe .enable_model_cpu_offload ()
810804
811805 weight = transformer .x_embedder .weight
812- self .assertTrue (isinstance (weight , AffineQuantizedTensor ))
806+ self .assertTrue (isinstance (weight , Int8Tensor ))
813807
814808 loaded_output = pipe (** inputs )[0 ].flatten ()[:128 ]
815809 # Seems to require higher tolerance depending on which machine it is being run.
@@ -897,7 +891,7 @@ def test_transformer_int8wo(self):
897891 # Verify that all linear layer weights are quantized
898892 for name , module in pipe .transformer .named_modules ():
899893 if isinstance (module , nn .Linear ):
900- self .assertTrue (isinstance (module .weight , AffineQuantizedTensor ))
894+ self .assertTrue (isinstance (module .weight , Int8Tensor ))
901895
902896 # Verify outputs match expected slice
903897 inputs = self .get_dummy_inputs (torch_device )
0 commit comments