diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 30d44a92c425..2d10a090d229 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -142,6 +142,14 @@ def _is_module_quantized(self, module): except (AssertionError, AttributeError): return False + def _get_dummy_inputs_for_model(self, model): + inputs = self.get_dummy_inputs() + model_dtype = next(model.parameters()).dtype + return { + k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in inputs.items() + } + def _load_unquantized_model(self): kwargs = getattr(self, "pretrained_model_kwargs", {}) return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) @@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) model_quantized.to(torch_device) - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model_quantized) output = model_quantized(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" @@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs): # Move LoRA adapter weights to device (they default to CPU) model.to(torch_device) - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model) + output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None with LoRA" @@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path): model_loaded = self.model_class.from_pretrained(str(tmp_path)) - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model_loaded) + output = model_loaded(**inputs, return_dict=False)[0] assert not torch.isnan(output).any(), "Loaded model output contains NaN" @@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs): assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute" assert model.hf_device_map is not None, "hf_device_map should not be None" - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model) + output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" assert not torch.isnan(output).any(), "Model output contains NaN" @@ -359,7 +370,8 @@ def _test_dequantize(self, config_kwargs): if isinstance(module, torch.nn.Linear): assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()" - inputs = self.get_dummy_inputs() + # Get model dtype from first parameter + inputs = self._get_dummy_inputs_for_model(model) output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None after dequantization" assert not torch.isnan(output).any(), "Model output contains NaN after dequantization" @@ -405,7 +417,7 @@ def _test_quantization_training(self, config_kwargs): pytest.skip("No attention layers found in model for adapter training test") # Step 3: run forward and backward pass - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model) with torch.amp.autocast(torch_device, dtype=torch.float16): out = model(**inputs, return_dict=False)[0] @@ -587,7 +599,7 @@ def test_bnb_keep_modules_in_fp32(self): f"Module {name} should be uint8 but is {module.weight.dtype}" ) - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model) _ = model(**inputs) def test_bnb_modules_to_not_convert(self): @@ -908,7 +920,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path): model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device)) - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model_loaded) + output = model_loaded(**inputs, return_dict=False)[0] assert not torch.isnan(output).any(), "Loaded model output contains NaN" @@ -1165,6 +1178,14 @@ class QuantizationCompileTesterMixin: - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass """ + def _get_dummy_inputs_for_model(self, model): + inputs = self.get_dummy_inputs() + model_dtype = next(model.parameters()).dtype + return { + k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in inputs.items() + } + def setup_method(self): gc.collect() backend_empty_cache(torch_device) @@ -1190,7 +1211,8 @@ def _test_torch_compile(self, config_kwargs): model = torch.compile(model, fullgraph=True) with torch._dynamo.config.patch(error_on_recompile=True): - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model) + output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" assert not torch.isnan(output).any(), "Model output contains NaN" @@ -1221,7 +1243,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False model.enable_group_offload(**group_offload_kwargs) model = torch.compile(model) - inputs = self.get_dummy_inputs() + inputs = self._get_dummy_inputs_for_model(model) + output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" assert not torch.isnan(output).any(), "Model output contains NaN"