Skip to content

Commit c5076b6

Browse files
committed
Implement _dequantize for TorchAO quantizer
- Add a `_dequantize()` method to `TorchAoHfQuantizer` that dequantizes `TorchAOBaseTensor` weights back to standard `nn.Parameter` objects.
- Fix `_verify_if_layer_quantized` to check `isinstance(weight, TorchAOBaseTensor)`, so that dequantized layers are correctly detected as non-quantized.
1 parent c8c8401 commit c5076b6

2 files changed

Lines changed: 19 additions & 0 deletions

File tree

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,3 +376,17 @@ def is_trainable(self):
376376
@property
377377
def is_compileable(self) -> bool:
378378
return True
379+
380+
def _dequantize(self, model):
    """Replace TorchAO-quantized `nn.Linear` weights in `model` with dequantized parameters.

    Every `nn.Linear` submodule whose `weight` is a `TorchAOBaseTensor` has that weight dequantized, moved to the
    device the quantized weight lived on, and re-wrapped in a fresh `nn.Parameter`. If the module's `extra_repr`
    is a bound method whose function is not `nn.Linear.extra_repr`, it is rebound to the stock implementation.
    Submodules are modified in place.

    Args:
        model: Module tree to walk (must provide `modules()`, e.g. a `torch.nn.Module`).

    Returns:
        The same `model` object that was passed in.
    """
    # Function-local import, as in the original; presumably keeps torchao optional at module import time.
    from torchao.utils import TorchAOBaseTensor

    # Only the module objects are needed, so iterate `modules()` rather than unpacking unused names.
    for module in model.modules():
        if not isinstance(module, nn.Linear):
            continue

        quantized_weight = module.weight
        if not isinstance(quantized_weight, TorchAOBaseTensor):
            continue

        # NOTE(review): `nn.Parameter` defaults to `requires_grad=True`; confirm that is the intended state
        # for dequantized weights.
        module.weight = nn.Parameter(quantized_weight.dequantize().to(quantized_weight.device))

        # Reset extra_repr if it was overridden: a bound method exposes `__func__`, which is compared against
        # the stock `nn.Linear.extra_repr` function.
        bound_func = getattr(module.extra_repr, "__func__", None)
        if bound_func is not None and bound_func is not nn.Linear.extra_repr:
            module.extra_repr = types.MethodType(nn.Linear.extra_repr, module)

    return model

tests/models/testing_utils/quantization.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
829829
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
830830

831831
def _verify_if_layer_quantized(self, name, module, config_kwargs):
    """Assert that `module` is a `torch.nn.Linear` whose weight is a `TorchAOBaseTensor`.

    Args:
        name: Layer name, used only in the assertion messages.
        module: The layer to inspect.
        config_kwargs: Not used by this check.

    Raises:
        AssertionError: If `module` is not a `torch.nn.Linear`, or its weight is not a `TorchAOBaseTensor`.
    """
    from torchao.utils import TorchAOBaseTensor

    assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"

    # Read the weight only after the Linear check has passed, matching the original evaluation order.
    weight = module.weight
    assert isinstance(
        weight, TorchAOBaseTensor
    ), f"Layer {name} weight is {type(weight)}, expected TorchAOBaseTensor"
833838

834839

835840
# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)

0 commit comments

Comments
 (0)