Skip to content

Commit 79218ca

Browse files
HuiyingLiclaude
andauthored
fix(peft): gate FP8+PEFT config kwarg injection on is_hf_model (#2169)
* fix(peft): gate FP8+PEFT config kwarg injection on is_hf_model When PEFT is requested and the HF config has FP8 quantization, the FP8 dequantize path injected `kwargs["config"] = _hf_config` so that HF's from_pretrained would honor the in-memory `dequantize=True` mutation. But the same kwargs flow into the custom-model branch in model_init.py, which calls `model_cls(hf_config, *model_args, **kwargs)` — passing hf_config positionally — and the duplicate `config` produced: TypeError: MiniMaxM2ForCausalLM.__init__() got multiple values for argument 'config' Gate the kwargs injection on is_hf_model. The dequantize=True flag is mutated in place on the quantization_config dict, so the custom-model path still observes it via the positional hf_config argument; only the HF path needs it surfaced as a kwarg. Fixes: #2164 Signed-off-by: HuiyingLi <willwin.lee@gmail.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * test: cover fp8 peft config kwarg gate Signed-off-by: HuiyingLi <willwin.lee@gmail.com> --------- Signed-off-by: HuiyingLi <willwin.lee@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 42b8788 commit 79218ca

2 files changed

Lines changed: 154 additions & 3 deletions

File tree

nemo_automodel/_transformers/auto_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,12 @@ def _retry(**override):
353353
)
354354
_hf_native_quant_cfg = getattr(_hf_config, "quantization_config", None)
355355
if _maybe_dequantize_fp8_for_peft(_hf_native_quant_cfg, peft_config, pretrained_model_name_or_path_or_config):
356-
kwargs["config"] = _hf_config
356+
# Only HF's from_pretrained needs `config` in kwargs (it would otherwise
357+
# re-read config from disk and lose the in-memory dequantize=True mutation).
358+
# Custom models receive _hf_config positionally in model_init.py and would
359+
# collide with kwargs["config"] (issue #2164).
360+
if is_hf_model:
361+
kwargs["config"] = _hf_config
357362

358363
# Use meta device initialization when:
359364
# - Not using MegatronFSDPManager or DDPManager (they handle their own initialization)

tests/unit_tests/_transformers/test_fp8_peft_dequantize.py

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
"""Tests for FP8 model + PEFT dequantization logic."""
1616

17-
from unittest.mock import MagicMock
17+
from unittest.mock import MagicMock, patch
1818

19-
from nemo_automodel._transformers.auto_model import _maybe_dequantize_fp8_for_peft
19+
from nemo_automodel._transformers.auto_model import _BaseNeMoAutoModelClass, _maybe_dequantize_fp8_for_peft
2020

2121
# ---------------------------------------------------------------------------
2222
# Tests: FP8 + PEFT auto-dequantize
@@ -173,3 +173,149 @@ def test_meta_device_disabled_single_gpu_hf_model(self):
173173
hf_native_quant_cfg=None,
174174
)
175175
assert result is False
176+
177+
178+
# ---------------------------------------------------------------------------
179+
# Tests: kwargs["config"] injection gated on is_hf_model (issue #2164)
180+
# ---------------------------------------------------------------------------
181+
182+
183+
class TestKwargsConfigInjectionGate:
184+
"""Tests for the is_hf_model gate on kwargs["config"] = _hf_config injection.
185+
186+
Custom models receive _hf_config positionally in model_init.py:783 via
187+
model_cls(hf_config, *model_args, **kwargs); injecting config into kwargs
188+
causes a TypeError ("got multiple values for argument 'config'"). The gate
189+
suppresses the injection for the custom-model path while preserving the
190+
in-place dequantize=True mutation needed by the HF path.
191+
"""
192+
193+
@staticmethod
194+
def _make_build_kwargs(is_hf_model):
195+
"""Minimal kwargs for running _build_model through the FP8+PEFT gate."""
196+
mesh = MagicMock()
197+
mesh.tp_size = 1
198+
mesh.cp_size = 1
199+
return dict(
200+
is_hf_model=is_hf_model,
201+
use_liger_kernel=False,
202+
use_sdpa_patching=False,
203+
sdpa_method=None,
204+
torch_dtype="auto",
205+
attn_implementation="eager",
206+
quantization_config=None,
207+
force_hf=False,
208+
model_wrapper=None,
209+
autopipeline=None,
210+
parallelize_fn=None,
211+
qat_quantizer=None,
212+
mesh=mesh,
213+
loss_fn=None,
214+
peft_config=MagicMock(),
215+
fp8_config=None,
216+
compile_config=None,
217+
load_base_model=True,
218+
)
219+
220+
@staticmethod
221+
def _run_build_model_with_native_fp8(is_hf_model):
222+
quant_cfg = {"quant_method": "fp8", "dequantize": False}
223+
hf_config = MagicMock()
224+
hf_config.quantization_config = quant_cfg
225+
sentinel_model = MagicMock()
226+
227+
with (
228+
patch("nemo_automodel._transformers.auto_model._apply_preload_overrides", return_value=("eager", False)),
229+
patch("nemo_automodel._transformers.auto_model.get_hf_config", return_value=hf_config),
230+
patch("nemo_automodel._transformers.auto_model._init_model") as mock_init,
231+
patch("nemo_automodel._transformers.auto_model.get_world_size_safe", return_value=1),
232+
patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"),
233+
patch(
234+
"nemo_automodel._transformers.capabilities.attach_capabilities_and_validate",
235+
return_value=sentinel_model,
236+
),
237+
patch("nemo_automodel._transformers.auto_model.apply_model_infrastructure", return_value=sentinel_model),
238+
patch("torch.cuda.current_device", return_value=0),
239+
):
240+
mock_init.return_value = (not is_hf_model, sentinel_model)
241+
result = _BaseNeMoAutoModelClass._build_model(
242+
"some-model",
243+
**TestKwargsConfigInjectionGate._make_build_kwargs(is_hf_model),
244+
)
245+
246+
return quant_cfg, hf_config, result, sentinel_model, mock_init
247+
248+
def test_build_model_hf_fp8_peft_injects_config_kwarg(self):
249+
"""_build_model should pass mutated config through kwargs for HF from_pretrained."""
250+
quant_cfg, hf_config, result, sentinel_model, mock_init = self._run_build_model_with_native_fp8(
251+
is_hf_model=True
252+
)
253+
254+
assert result is sentinel_model
255+
assert mock_init.call_args.kwargs["config"] is hf_config
256+
assert quant_cfg["dequantize"] is True
257+
258+
def test_build_model_custom_fp8_peft_does_not_inject_config_kwarg(self):
259+
"""_build_model should not pass duplicate config kwargs for custom model init."""
260+
quant_cfg, _hf_config, result, sentinel_model, mock_init = self._run_build_model_with_native_fp8(
261+
is_hf_model=False
262+
)
263+
264+
assert result is sentinel_model
265+
assert "config" not in mock_init.call_args.kwargs
266+
assert quant_cfg["dequantize"] is True
267+
268+
@staticmethod
269+
def _apply_gate(hf_native_quant_cfg, peft_config, pretrained_path, is_hf_model, hf_config_obj):
270+
"""Replicate the gated kwargs["config"] injection from _build_model."""
271+
kwargs: dict = {}
272+
if _maybe_dequantize_fp8_for_peft(hf_native_quant_cfg, peft_config, pretrained_path):
273+
if is_hf_model:
274+
kwargs["config"] = hf_config_obj
275+
return kwargs
276+
277+
def test_hf_model_fp8_peft_injects_config_kwarg(self):
278+
"""HF path needs config in kwargs so HF.from_pretrained sees the dequantize mutation."""
279+
quant_cfg = {"quant_method": "fp8", "dequantize": False}
280+
hf_config = MagicMock()
281+
hf_config.quantization_config = quant_cfg
282+
283+
kwargs = self._apply_gate(quant_cfg, MagicMock(), "some-model", is_hf_model=True, hf_config_obj=hf_config)
284+
285+
assert "config" in kwargs
286+
assert kwargs["config"] is hf_config
287+
assert quant_cfg["dequantize"] is True
288+
289+
def test_custom_model_fp8_peft_does_not_inject_config_kwarg(self):
290+
"""Custom-model path receives hf_config positionally; injecting config would TypeError (#2164)."""
291+
quant_cfg = {"quant_method": "fp8", "dequantize": False}
292+
hf_config = MagicMock()
293+
hf_config.quantization_config = quant_cfg
294+
295+
kwargs = self._apply_gate(quant_cfg, MagicMock(), "some-model", is_hf_model=False, hf_config_obj=hf_config)
296+
297+
assert "config" not in kwargs
298+
# Dequantize mutation must still be applied so the custom path sees it via the
299+
# positional hf_config argument.
300+
assert quant_cfg["dequantize"] is True
301+
302+
def test_no_peft_does_not_inject_regardless_of_is_hf_model(self):
303+
"""When PEFT is not configured, no injection happens on either path."""
304+
quant_cfg = {"quant_method": "fp8", "dequantize": False}
305+
hf_config = MagicMock()
306+
307+
kwargs_hf = self._apply_gate(quant_cfg, None, "some-model", is_hf_model=True, hf_config_obj=hf_config)
308+
kwargs_custom = self._apply_gate(quant_cfg, None, "some-model", is_hf_model=False, hf_config_obj=hf_config)
309+
310+
assert "config" not in kwargs_hf
311+
assert "config" not in kwargs_custom
312+
assert quant_cfg["dequantize"] is False
313+
314+
def test_non_fp8_quant_does_not_inject(self):
315+
"""Non-FP8 quant configs (e.g. GPTQ) are not the FP8+PEFT case; no injection."""
316+
quant_cfg = {"quant_method": "gptq", "bits": 4}
317+
hf_config = MagicMock()
318+
319+
kwargs = self._apply_gate(quant_cfg, MagicMock(), "some-model", is_hf_model=True, hf_config_obj=hf_config)
320+
321+
assert "config" not in kwargs

0 commit comments

Comments
 (0)