Skip to content

Commit 24b4c25

Browse files
Remove references to torchao's AffineQuantizedTensor (#13405)
**Summary:** TorchAO recently deprecated AffineQuantizedTensor and related classes (pytorch/ao#2752). These will be removed in the next release. We should remove references of these classes in diffusers before then. **Test Plan:** python -m pytest -s -v tests/quantization/torchao/test_torchao.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent d31061b commit 24b4c25

File tree

2 files changed

+23
-38
lines changed

2 files changed

+23
-38
lines changed

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,10 @@ def fuzzy_match_size(config_name: str) -> str | None:
133133
return None
134134

135135

136-
def _quantization_type(weight):
137-
from torchao.dtypes import AffineQuantizedTensor
138-
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
139-
140-
if isinstance(weight, AffineQuantizedTensor):
141-
return f"{weight.__class__.__name__}({weight._quantization_type()})"
142-
143-
if isinstance(weight, LinearActivationQuantizedTensor):
144-
return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
145-
146-
147136
def _linear_extra_repr(self):
148-
weight = _quantization_type(self.weight)
137+
from torchao.utils import TorchAOBaseTensor
138+
139+
weight = self.weight.__class__.__name__ if isinstance(self.weight, TorchAOBaseTensor) else None
149140
if weight is None:
150141
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
151142
else:
@@ -283,12 +274,12 @@ def create_quantized_param(
283274

284275
if self.pre_quantized:
285276
# If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info
286-
# about AffineQuantizedTensor
277+
# about the quantized tensor type
287278
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
288279
if isinstance(module, nn.Linear):
289280
module.extra_repr = types.MethodType(_linear_extra_repr, module)
290281
else:
291-
# As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves
282+
# As we perform quantization here, the repr of linear layers is set by TorchAO, so we don't have to do it ourselves
292283
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
293284
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
294285

tests/quantization/torchao/test_torchao.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
7575

7676

7777
if is_torchao_available():
78-
from torchao.dtypes import AffineQuantizedTensor
7978
from torchao.quantization import (
8079
Float8WeightOnlyConfig,
80+
Int4Tensor,
8181
Int4WeightOnlyConfig,
8282
Int8DynamicActivationInt8WeightConfig,
8383
Int8DynamicActivationIntxWeightConfig,
84+
Int8Tensor,
8485
Int8WeightOnlyConfig,
8586
IntxWeightOnlyConfig,
8687
)
87-
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
88-
from torchao.utils import get_model_size_in_bytes
88+
from torchao.utils import TorchAOBaseTensor, get_model_size_in_bytes
8989

9090

9191
@require_torch
@@ -260,9 +260,7 @@ def test_int4wo_quant_bfloat16_conversion(self):
260260
)
261261

262262
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
263-
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
264-
self.assertEqual(weight.quant_min, 0)
265-
self.assertEqual(weight.quant_max, 15)
263+
self.assertTrue(isinstance(weight, Int4Tensor))
266264

267265
def test_device_map(self):
268266
"""
@@ -322,7 +320,7 @@ def test_device_map(self):
322320
if "transformer_blocks.0" in device_map:
323321
self.assertTrue(isinstance(weight, nn.Parameter))
324322
else:
325-
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
323+
self.assertTrue(isinstance(weight, Int4Tensor))
326324

327325
output = quantized_model(**inputs)[0]
328326
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
@@ -343,7 +341,7 @@ def test_device_map(self):
343341
if "transformer_blocks.0" in device_map:
344342
self.assertTrue(isinstance(weight, nn.Parameter))
345343
else:
346-
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
344+
self.assertTrue(isinstance(weight, Int4Tensor))
347345

348346
output = quantized_model(**inputs)[0]
349347
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
@@ -360,11 +358,11 @@ def test_modules_to_not_convert(self):
360358

361359
unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
362360
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
363-
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
361+
self.assertFalse(isinstance(unquantized_layer.weight, Int8Tensor))
364362
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
365363

366364
quantized_layer = quantized_model_with_not_convert.proj_out
367-
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
365+
self.assertTrue(isinstance(quantized_layer.weight, Int8Tensor))
368366

369367
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
370368
quantized_model = FluxTransformer2DModel.from_pretrained(
@@ -448,18 +446,18 @@ def test_memory_footprint(self):
448446

449447
# Will not quantize all the layers by default due to the model weights shapes not being divisible by group_size=64
450448
for block in transformer_int4wo.transformer_blocks:
451-
self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
452-
self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
449+
self.assertTrue(isinstance(block.ff.net[2].weight, Int4Tensor))
450+
self.assertTrue(isinstance(block.ff_context.net[2].weight, Int4Tensor))
453451

454452
# Will quantize all the linear layers except x_embedder
455453
for name, module in transformer_int4wo_gs32.named_modules():
456454
if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
457-
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
455+
self.assertTrue(isinstance(module.weight, Int4Tensor))
458456

459457
# Will quantize all the linear layers
460458
for module in transformer_int8wo.modules():
461459
if isinstance(module, nn.Linear):
462-
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
460+
self.assertTrue(isinstance(module.weight, Int8Tensor))
463461

464462
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
465463
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
@@ -588,7 +586,7 @@ def _test_original_model_expected_slice(self, quant_type, expected_slice):
588586
output = quantized_model(**inputs)[0]
589587
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
590588
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
591-
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
589+
self.assertTrue(isinstance(weight, TorchAOBaseTensor))
592590
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
593591

594592
def _check_serialization_expected_slice(self, quant_type, expected_slice, device):
@@ -604,11 +602,7 @@ def _check_serialization_expected_slice(self, quant_type, expected_slice, device
604602
output = loaded_quantized_model(**inputs)[0]
605603

606604
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
607-
self.assertTrue(
608-
isinstance(
609-
loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
610-
)
611-
)
605+
self.assertTrue(isinstance(loaded_quantized_model.proj_out.weight, TorchAOBaseTensor))
612606
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
613607

614608
def test_int_a8w8_accelerator(self):
@@ -756,7 +750,7 @@ def _test_quant_type(self, quantization_config, expected_slice):
756750
pipe.enable_model_cpu_offload()
757751

758752
weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
759-
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
753+
self.assertTrue(isinstance(weight, TorchAOBaseTensor))
760754

761755
inputs = self.get_dummy_inputs(torch_device)
762756
output = pipe(**inputs)[0].flatten()
@@ -790,7 +784,7 @@ def test_serialization_int8wo(self):
790784
pipe.enable_model_cpu_offload()
791785

792786
weight = pipe.transformer.x_embedder.weight
793-
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
787+
self.assertTrue(isinstance(weight, Int8Tensor))
794788

795789
inputs = self.get_dummy_inputs(torch_device)
796790
output = pipe(**inputs)[0].flatten()[:128]
@@ -809,7 +803,7 @@ def test_serialization_int8wo(self):
809803
pipe.enable_model_cpu_offload()
810804

811805
weight = transformer.x_embedder.weight
812-
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
806+
self.assertTrue(isinstance(weight, Int8Tensor))
813807

814808
loaded_output = pipe(**inputs)[0].flatten()[:128]
815809
# Seems to require higher tolerance depending on which machine it is being run.
@@ -897,7 +891,7 @@ def test_transformer_int8wo(self):
897891
# Verify that all linear layer weights are quantized
898892
for name, module in pipe.transformer.named_modules():
899893
if isinstance(module, nn.Linear):
900-
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
894+
self.assertTrue(isinstance(module.weight, Int8Tensor))
901895

902896
# Verify outputs match expected slice
903897
inputs = self.get_dummy_inputs(torch_device)

0 commit comments

Comments
 (0)