Skip to content

Commit 907c0c2

Browse files
Implement _dequantize for TorchAO quantizer (#13538)
* Implement _dequantize for TorchAO quantizer
  - Add _dequantize() method in TorchAoHfQuantizer that dequantizes TorchAOBaseTensor weights back to standard nn.Parameter
  - Fix _verify_if_layer_quantized to check isinstance(weight, TorchAOBaseTensor) so dequantized layers are correctly detected as non-quantized
* enable dequantize for TorchAO tester mixin
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* check dequantize
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix dequantize: clear is_quantized flag and cast dtype after dequantize
* fix
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix error report
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 387a471 commit 907c0c2

4 files changed

Lines changed: 33 additions & 0 deletions

File tree

src/diffusers/quantizers/base.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,7 @@ def dequantize(self, model):
206206

207207
# Delete quantizer and quantization config
208208
del model.hf_quantizer
209+
model.is_quantized = False
209210

210211
return model
211212

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,3 +376,22 @@ def is_trainable(self):
376376
@property
377377
def is_compileable(self) -> bool:
378378
return True
379+
380+
def _dequantize(self, model):
    """Replace TorchAO-quantized ``nn.Linear`` weights in ``model`` with plain parameters.

    Walks ``model.named_modules()`` and, for every ``nn.Linear`` whose weight is a
    ``TorchAOBaseTensor``, calls ``dequantize()`` on that weight and stores the result
    as a new ``nn.Parameter`` on the weight's original device. Modules are modified
    in place; other modules are left untouched.

    Args:
        model: The module tree to walk.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        NotImplementedError: If ``dequantize()`` is not implemented for the tensor
            subclass of a quantized weight.
    """
    from torchao.utils import TorchAOBaseTensor

    for name, module in model.named_modules():
        if not (isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor)):
            continue

        quantized_weight = module.weight
        device = quantized_weight.device

        # `torch.Tensor` itself defines `dequantize`, so a `hasattr` probe is always
        # true for a tensor subclass and can never detect an unsupported type. Attempt
        # the call instead and translate the failure into an actionable message.
        # NOTE(review): assumes an unsupported subclass surfaces as NotImplementedError
        # from the call -- confirm against torchao's dispatch behaviour.
        try:
            dequantized_weight = quantized_weight.dequantize()
        except NotImplementedError as err:
            raise NotImplementedError(
                f"Dequantization is not supported for {type(quantized_weight).__name__} "
                f"(module: {name}). Please use a quantization type that supports dequantization."
            ) from err

        module.weight = nn.Parameter(dequantized_weight.to(device))

        # Reset extra_repr if it was overridden
        if hasattr(module.extra_repr, "__func__") and module.extra_repr.__func__ is not nn.Linear.extra_repr:
            module.extra_repr = types.MethodType(nn.Linear.extra_repr, module)

    return model

tests/models/testing_utils/quantization.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
822822
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
823823

824824
def _verify_if_layer_quantized(self, name, module, config_kwargs):
825+
from torchao.utils import TorchAOBaseTensor
826+
825827
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
828+
assert isinstance(module.weight, TorchAOBaseTensor), (
829+
f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
830+
)
826831

827832

828833
# int4wo requires CUDA or XPU ops (_convert_weight_to_int4pack)

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,10 @@ def pretrained_model_kwargs(self):
368368
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
369369
"""TorchAO quantization tests for Flux Transformer."""
370370

371+
@property
372+
def torch_dtype(self):
373+
return torch.bfloat16
374+
371375

372376
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
373377
@property
@@ -404,6 +408,10 @@ class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompil
404408
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
405409
"""TorchAO + compile tests for Flux Transformer."""
406410

411+
@property
412+
def torch_dtype(self):
413+
return torch.bfloat16
414+
407415

408416
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
409417
@property

0 commit comments

Comments
 (0)