|
14 | 14 |
|
15 | 15 | """Tests for FP8 model + PEFT dequantization logic.""" |
16 | 16 |
|
17 | | -from unittest.mock import MagicMock |
| 17 | +from unittest.mock import MagicMock, patch |
18 | 18 |
|
19 | | -from nemo_automodel._transformers.auto_model import _maybe_dequantize_fp8_for_peft |
| 19 | +from nemo_automodel._transformers.auto_model import _BaseNeMoAutoModelClass, _maybe_dequantize_fp8_for_peft |
20 | 20 |
|
21 | 21 | # --------------------------------------------------------------------------- |
22 | 22 | # Tests: FP8 + PEFT auto-dequantize |
@@ -173,3 +173,149 @@ def test_meta_device_disabled_single_gpu_hf_model(self): |
173 | 173 | hf_native_quant_cfg=None, |
174 | 174 | ) |
175 | 175 | assert result is False |
| 176 | + |
| 177 | + |
| 178 | +# --------------------------------------------------------------------------- |
| 179 | +# Tests: kwargs["config"] injection gated on is_hf_model (issue #2164) |
| 180 | +# --------------------------------------------------------------------------- |
| 181 | + |
| 182 | + |
| 183 | +class TestKwargsConfigInjectionGate: |
| 184 | + """Tests for the is_hf_model gate on kwargs["config"] = _hf_config injection. |
| 185 | +
|
| 186 | + Custom models receive _hf_config positionally in model_init.py:783 via |
| 187 | + model_cls(hf_config, *model_args, **kwargs); injecting config into kwargs |
| 188 | + causes a TypeError ("got multiple values for argument 'config'"). The gate |
| 189 | + suppresses the injection for the custom-model path while preserving the |
| 190 | + in-place dequantize=True mutation needed by the HF path. |
| 191 | + """ |
| 192 | + |
| 193 | + @staticmethod |
| 194 | + def _make_build_kwargs(is_hf_model): |
| 195 | + """Minimal kwargs for running _build_model through the FP8+PEFT gate.""" |
| 196 | + mesh = MagicMock() |
| 197 | + mesh.tp_size = 1 |
| 198 | + mesh.cp_size = 1 |
| 199 | + return dict( |
| 200 | + is_hf_model=is_hf_model, |
| 201 | + use_liger_kernel=False, |
| 202 | + use_sdpa_patching=False, |
| 203 | + sdpa_method=None, |
| 204 | + torch_dtype="auto", |
| 205 | + attn_implementation="eager", |
| 206 | + quantization_config=None, |
| 207 | + force_hf=False, |
| 208 | + model_wrapper=None, |
| 209 | + autopipeline=None, |
| 210 | + parallelize_fn=None, |
| 211 | + qat_quantizer=None, |
| 212 | + mesh=mesh, |
| 213 | + loss_fn=None, |
| 214 | + peft_config=MagicMock(), |
| 215 | + fp8_config=None, |
| 216 | + compile_config=None, |
| 217 | + load_base_model=True, |
| 218 | + ) |
| 219 | + |
| 220 | + @staticmethod |
| 221 | + def _run_build_model_with_native_fp8(is_hf_model): |
| 222 | + quant_cfg = {"quant_method": "fp8", "dequantize": False} |
| 223 | + hf_config = MagicMock() |
| 224 | + hf_config.quantization_config = quant_cfg |
| 225 | + sentinel_model = MagicMock() |
| 226 | + |
| 227 | + with ( |
| 228 | + patch("nemo_automodel._transformers.auto_model._apply_preload_overrides", return_value=("eager", False)), |
| 229 | + patch("nemo_automodel._transformers.auto_model.get_hf_config", return_value=hf_config), |
| 230 | + patch("nemo_automodel._transformers.auto_model._init_model") as mock_init, |
| 231 | + patch("nemo_automodel._transformers.auto_model.get_world_size_safe", return_value=1), |
| 232 | + patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"), |
| 233 | + patch( |
| 234 | + "nemo_automodel._transformers.capabilities.attach_capabilities_and_validate", |
| 235 | + return_value=sentinel_model, |
| 236 | + ), |
| 237 | + patch("nemo_automodel._transformers.auto_model.apply_model_infrastructure", return_value=sentinel_model), |
| 238 | + patch("torch.cuda.current_device", return_value=0), |
| 239 | + ): |
| 240 | + mock_init.return_value = (not is_hf_model, sentinel_model) |
| 241 | + result = _BaseNeMoAutoModelClass._build_model( |
| 242 | + "some-model", |
| 243 | + **TestKwargsConfigInjectionGate._make_build_kwargs(is_hf_model), |
| 244 | + ) |
| 245 | + |
| 246 | + return quant_cfg, hf_config, result, sentinel_model, mock_init |
| 247 | + |
| 248 | + def test_build_model_hf_fp8_peft_injects_config_kwarg(self): |
| 249 | + """_build_model should pass mutated config through kwargs for HF from_pretrained.""" |
| 250 | + quant_cfg, hf_config, result, sentinel_model, mock_init = self._run_build_model_with_native_fp8( |
| 251 | + is_hf_model=True |
| 252 | + ) |
| 253 | + |
| 254 | + assert result is sentinel_model |
| 255 | + assert mock_init.call_args.kwargs["config"] is hf_config |
| 256 | + assert quant_cfg["dequantize"] is True |
| 257 | + |
| 258 | + def test_build_model_custom_fp8_peft_does_not_inject_config_kwarg(self): |
| 259 | + """_build_model should not pass duplicate config kwargs for custom model init.""" |
| 260 | + quant_cfg, _hf_config, result, sentinel_model, mock_init = self._run_build_model_with_native_fp8( |
| 261 | + is_hf_model=False |
| 262 | + ) |
| 263 | + |
| 264 | + assert result is sentinel_model |
| 265 | + assert "config" not in mock_init.call_args.kwargs |
| 266 | + assert quant_cfg["dequantize"] is True |
| 267 | + |
| 268 | + @staticmethod |
| 269 | + def _apply_gate(hf_native_quant_cfg, peft_config, pretrained_path, is_hf_model, hf_config_obj): |
| 270 | + """Replicate the gated kwargs["config"] injection from _build_model.""" |
| 271 | + kwargs: dict = {} |
| 272 | + if _maybe_dequantize_fp8_for_peft(hf_native_quant_cfg, peft_config, pretrained_path): |
| 273 | + if is_hf_model: |
| 274 | + kwargs["config"] = hf_config_obj |
| 275 | + return kwargs |
| 276 | + |
| 277 | + def test_hf_model_fp8_peft_injects_config_kwarg(self): |
| 278 | + """HF path needs config in kwargs so HF.from_pretrained sees the dequantize mutation.""" |
| 279 | + quant_cfg = {"quant_method": "fp8", "dequantize": False} |
| 280 | + hf_config = MagicMock() |
| 281 | + hf_config.quantization_config = quant_cfg |
| 282 | + |
| 283 | + kwargs = self._apply_gate(quant_cfg, MagicMock(), "some-model", is_hf_model=True, hf_config_obj=hf_config) |
| 284 | + |
| 285 | + assert "config" in kwargs |
| 286 | + assert kwargs["config"] is hf_config |
| 287 | + assert quant_cfg["dequantize"] is True |
| 288 | + |
| 289 | + def test_custom_model_fp8_peft_does_not_inject_config_kwarg(self): |
| 290 | + """Custom-model path receives hf_config positionally; injecting config would TypeError (#2164).""" |
| 291 | + quant_cfg = {"quant_method": "fp8", "dequantize": False} |
| 292 | + hf_config = MagicMock() |
| 293 | + hf_config.quantization_config = quant_cfg |
| 294 | + |
| 295 | + kwargs = self._apply_gate(quant_cfg, MagicMock(), "some-model", is_hf_model=False, hf_config_obj=hf_config) |
| 296 | + |
| 297 | + assert "config" not in kwargs |
| 298 | + # Dequantize mutation must still be applied so the custom path sees it via the |
| 299 | + # positional hf_config argument. |
| 300 | + assert quant_cfg["dequantize"] is True |
| 301 | + |
| 302 | + def test_no_peft_does_not_inject_regardless_of_is_hf_model(self): |
| 303 | + """When PEFT is not configured, no injection happens on either path.""" |
| 304 | + quant_cfg = {"quant_method": "fp8", "dequantize": False} |
| 305 | + hf_config = MagicMock() |
| 306 | + |
| 307 | + kwargs_hf = self._apply_gate(quant_cfg, None, "some-model", is_hf_model=True, hf_config_obj=hf_config) |
| 308 | + kwargs_custom = self._apply_gate(quant_cfg, None, "some-model", is_hf_model=False, hf_config_obj=hf_config) |
| 309 | + |
| 310 | + assert "config" not in kwargs_hf |
| 311 | + assert "config" not in kwargs_custom |
| 312 | + assert quant_cfg["dequantize"] is False |
| 313 | + |
| 314 | + def test_non_fp8_quant_does_not_inject(self): |
| 315 | + """Non-FP8 quant configs (e.g. GPTQ) are not the FP8+PEFT case; no injection.""" |
| 316 | + quant_cfg = {"quant_method": "gptq", "bits": 4} |
| 317 | + hf_config = MagicMock() |
| 318 | + |
| 319 | + kwargs = self._apply_gate(quant_cfg, MagicMock(), "some-model", is_hf_model=True, hf_config_obj=hf_config) |
| 320 | + |
| 321 | + assert "config" not in kwargs |
0 commit comments