@@ -359,6 +359,10 @@ def _test_dequantize(self, config_kwargs):
             if isinstance(module, torch.nn.Linear):
                 assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
 
+        # Cast model parameters to the expected dtype after dequantization (weights may be float32)
+        for param in model.parameters():
+            param.data = param.data.to(self.torch_dtype)
+
         inputs = self.get_dummy_inputs()
         output = model(**inputs, return_dict=False)[0]
        assert output is not None, "Model output is None after dequantization"
@@ -931,18 +935,9 @@ def test_torchao_device_map(self):
        """Test that device_map='auto' works correctly with quantization."""
        self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
 
-    @pytest.mark.parametrize(
-        "quant_type",
-        [
-            pytest.param("int4wo", marks=_int4wo_skip),
-            "int8wo",
-            "int8dq",
-        ],
-        ids=["int4wo", "int8wo", "int8dq"],
-    )
-    def test_torchao_dequantize(self, quant_type):
+    def test_torchao_dequantize(self):
        """Test that dequantize() works correctly."""
-        self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])
+        self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
 
     def test_torchao_training(self):
        """Test that quantized models can be used for training with adapters."""