@@ -359,6 +359,9 @@ def _test_dequantize(self, config_kwargs):
359359 if isinstance (module , torch .nn .Linear ):
360360 assert not self ._is_module_quantized (module ), f"Module { name } is still quantized after dequantize()"
361361
362+ # Cast model to the expected dtype after dequantization (weights may be float32)
363+ model .to (self .torch_dtype )
364+
362365 inputs = self .get_dummy_inputs ()
363366 output = model (** inputs , return_dict = False )[0 ]
364367 assert output is not None , "Model output is None after dequantization"
@@ -931,18 +934,9 @@ def test_torchao_device_map(self):
931934 """Test that device_map='auto' works correctly with quantization."""
932935 self ._test_quantization_device_map (TorchAoConfigMixin .TORCHAO_QUANT_TYPES ["int8wo" ])
933936
934- @pytest .mark .parametrize (
935- "quant_type" ,
936- [
937- pytest .param ("int4wo" , marks = _int4wo_skip ),
938- "int8wo" ,
939- "int8dq" ,
940- ],
941- ids = ["int4wo" , "int8wo" , "int8dq" ],
942- )
943- def test_torchao_dequantize (self , quant_type ):
937+ def test_torchao_dequantize (self ):
944938 """Test that dequantize() works correctly."""
945- self ._test_dequantize (TorchAoConfigMixin .TORCHAO_QUANT_TYPES [quant_type ])
939+ self ._test_dequantize (TorchAoConfigMixin .TORCHAO_QUANT_TYPES ["int8wo" ])
946940
947941 def test_torchao_training (self ):
948942 """Test that quantized models can be used for training with adapters."""
0 commit comments