
Commit 07123bd

adil-a and claude authored and committed
fix: Baichuan2 checkpoint robustness test CI failures (NVIDIA-NeMo#1727)
* fix: checkpoint robustness test CI failures
  - Add trust_remote_code: true to baichuan ci.checkpoint_robustness
  - Add hf_device_map_auto: true to nemotron nano configs
  - Bump robustness global_batch_size 16→32 for multi-node compatibility
  - Remove hardcoded trust_remote_code=False that broke tokenizer loading
  - Fix dotted keys in ci.checkpoint_robustness being silently ignored
    (e.g. distributed.tp_size, dataset.limit_dataset_samples)

* fix: Baichuan2 checkpoint robustness test CI failures
  - Register an MLP-only TP plan for BaichuanForCausalLM (NormHead is not
    nn.Linear and W_pack has a non-interleaved QKV layout; both are
    incompatible with ColwiseParallel)
  - Fix the HF remote-code meta-tensor issue: RotaryEmbedding creates
    inv_freq/cos_cached/sin_cached as plain attributes that stay on the
    meta device; added a _fix_meta_rotary_embeddings helper for Phase 4
  - Set appropriate KL/loss thresholds for Baichuan2 with TP=2

* fix: Baichuan2 PEFT checkpoint robustness test CI failures
  - Apply _fix_meta_rotary_embeddings to the PEFT base-model loading path
  - Add KL/loss thresholds to the baichuan_2_7b_squad_peft.yaml CI config

* fix: remove unused cross-TP/resume settings from Baichuan2 PEFT config
  Cross-TP and resume assertions are skipped for PEFT models in the test.

* fix: add gc.collect() before torch.cuda.empty_cache() in the checkpoint robustness test
  FSDP2/DTensor circular references prevented GPU memory from being freed
  between test phases, causing OOM on large models (e.g. Nemotron Super
  120B) when Phase 4 reloads via vanilla HF with device_map="auto".

* fix: PEFT checkpoint restore for MoE models with activation checkpointing
  - Strip _checkpoint_wrapped_module. from FQNs in _get_peft_state_dict and
    _set_peft_state_dict to match DCP's normalization. Without this, expert
    LoRA weights are silently skipped on reload when activation checkpointing
    is enabled (the keys mismatch), causing a KL divergence of ~0.5.
  - Wire up the no_check_hf flag to skip the Phase 4 vanilla HF check when
    configured
  - Qwen3 MoE 30B LoRA: reduce to 1 node, add no_check_hf

* fix: Qwen3 MoE PEFT adapter HF compatibility via ParamWrapper format
  Save Qwen3 MoE expert LoRA adapters in the PEFT v0.18+ ParamWrapper format
  so PeftModel.from_pretrained() can load them directly. Previously, adapters
  were saved with individual per-expert keys
  (experts.0.gate_proj.lora_A.weight), which vanilla HF could not load
  because Qwen3 MoE uses fused nn.Parameter tensors (experts.gate_up_proj)
  rather than an individual nn.Module per expert. The new format (the
  default, v4_compatible=False) uses target_parameters in
  adapter_config.json and 2-D fused LoRA tensors matching ParamWrapper's
  expected key layout. The legacy per-expert format is preserved when
  v4_compatible=True.
  Also: reduce the Qwen3 MoE CI from 2 nodes to 1, remove dead no_check_hf
  parsing from the test, and clean up the _extract_target_modules helpers.

* fix: remove a debug print statement from the checkpoint robustness test

---------

Signed-off-by: adil-a <adil.asif2000@hotmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
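The gc.collect()-before-empty_cache() ordering described in the message can be sketched as follows; `release_gpu_memory` is an illustrative helper for this pattern, not the test's actual code:

```python
import gc


def release_gpu_memory() -> None:
    """Free GPU memory held alive by reference cycles between test phases.

    FSDP2/DTensor objects can form circular references, so CUDA tensors are
    only released once the cycle collector runs. Running gc.collect() first
    drops the tensors' refcounts to zero, so the subsequent empty_cache()
    can actually return the cached blocks to the driver.
    """
    gc.collect()  # break reference cycles so tensor storage is freed
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached blocks back to the driver
    except ImportError:
        pass  # torch not installed; nothing GPU-side to release
```

Calling empty_cache() without the preceding collect() is a no-op for memory still referenced by a cycle, which is exactly the OOM mode the commit describes.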
1 parent 35b956e commit 07123bd

12 files changed

Lines changed: 309 additions & 54 deletions

File tree

examples/llm_finetune/baichuan/baichuan_2_7b_squad.yaml

Lines changed: 5 additions & 2 deletions
@@ -104,10 +104,13 @@ ci:
     vllm_deploy: true
     recipe_owner: adil-a
     checkpoint_robustness:
-      hf_kl_threshold: 5e-3
+      trust_remote_code: true
+      kl_threshold: 1e-2
+      hf_kl_threshold: 5e-2
       distributed.tp_size: 2
       cross_tp_size: 2
-      cross_tp_kl_threshold: 5e-3
+      cross_tp_kl_threshold: 1e-2
+      resume_loss_threshold: 5e-2
       tokenizer_name: baichuan-inc/Baichuan2-7B-Chat
       dataset.limit_dataset_samples: 500
       validation_dataset.limit_dataset_samples: 500
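The dotted keys in this block (distributed.tp_size, dataset.limit_dataset_samples) are the ones the commit message says were previously silently ignored. A minimal sketch of how flat dotted overrides can be expanded into a nested config; `apply_dotted_overrides` is a hypothetical helper, not the repo's actual parser:

```python
def apply_dotted_overrides(config: dict, overrides: dict) -> dict:
    """Apply flat overrides whose keys may use dotted paths.

    "distributed.tp_size": 2 becomes config["distributed"]["tp_size"] = 2.
    Intermediate dicts are created on demand, so an override is applied
    rather than silently dropped when its parent key is missing.
    """
    for key, value in overrides.items():
        node = config
        parts = key.split(".")
        for part in parts[:-1]:
            node = node.setdefault(part, {})  # descend, creating dicts as needed
        node[parts[-1]] = value
    return config


base = {"distributed": {"tp_size": 1}, "dataset": {}}
patched = apply_dotted_overrides(
    base,
    {"distributed.tp_size": 2, "dataset.limit_dataset_samples": 500},
)
```

A loader that only does `config[key] = value` would instead store the literal key "distributed.tp_size" at the top level, where nothing reads it.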

examples/llm_finetune/baichuan/baichuan_2_7b_squad_peft.yaml

Lines changed: 2 additions & 1 deletion
@@ -121,8 +121,9 @@ ci:
     vllm_deploy: true
     recipe_owner: adil-a
     checkpoint_robustness:
-      hf_kl_threshold: 5e-3
       trust_remote_code: true
+      kl_threshold: 1e-2
+      hf_kl_threshold: 5e-2
       distributed.tp_size: 2
       tokenizer_name: baichuan-inc/Baichuan2-7B-Chat
       dataset.limit_dataset_samples: 500

examples/llm_finetune/nemotron/nemotron_nano_v3_hellaswag.yaml

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ ci:
     time: "00:15:00"
     checkpoint_robustness:
       hf_kl_threshold: 7e-2
+      hf_device_map_auto: true
       experts_implementation: grouped_mm
       tokenizer_name: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
       no_check_resume: true

examples/llm_finetune/nemotron/nemotron_nano_v3_hellaswag_peft.yaml

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ ci:
     time: "00:15:00"
     checkpoint_robustness:
       hf_kl_threshold: 1e-1
+      hf_device_map_auto: true
       experts_implementation: grouped_mm
       trust_remote_code: true
       tokenizer_name: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

examples/llm_finetune/qwen/qwen3_moe_30b_lora.yaml

Lines changed: 3 additions & 1 deletion
@@ -102,10 +102,12 @@ optimizer:
 ci:
   recipe_owner: adil-a
   time: "00:15:00"
-  nodes: 2
+  nodes: 1
   checkpoint_robustness:
     hf_kl_threshold: 7e-2
     tokenizer_name: Qwen/Qwen3-30B-A3B
+    trust_remote_code: true
+    hf_device_map_auto: true
     no_check_resume: true
     dataset.num_samples_limit: 500
     validation_dataset.num_samples_limit: 500

nemo_automodel/components/checkpoint/addons.py

Lines changed: 51 additions & 28 deletions
@@ -155,7 +155,8 @@ def pre_save(self, **kwargs) -> None:
         model_state = kwargs["model_state"]
         peft_config = kwargs["peft_config"]
         original_model_path = kwargs["original_model_path"]
-        hf_peft_config = _get_hf_peft_config(peft_config, model_state)
+        v4_compatible = kwargs.get("v4_compatible", False)
+        hf_peft_config = _get_hf_peft_config(peft_config, model_state, v4_compatible=v4_compatible)
         automodel_peft_metadata = _get_automodel_peft_metadata(peft_config)
         if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
             # if the HF model has custom model code, we need to save it as part of the checkpoint
@@ -176,13 +177,14 @@ def post_save(self, **kwargs) -> None:
     pass


-def _get_hf_peft_config(peft_config: "PeftConfig", model_state: ModelState) -> dict:
+def _get_hf_peft_config(peft_config: "PeftConfig", model_state: ModelState, v4_compatible: bool = False) -> dict:
     """
     Get the minimal PEFT config in the format expected by Hugging Face.

     Args:
         peft_config: Source PEFT configuration.
         model_state: Model wrapper used to infer target modules and model task.
+        v4_compatible: When True, use legacy per-expert expansion format.

     Returns:
         A dictionary containing the minimal HF-compatible PEFT configuration
@@ -197,7 +199,8 @@ def _get_hf_peft_config(peft_config: "PeftConfig", model_state: ModelState) -> d
         "FeatureExtraction": "FEATURE_EXTRACTION",
     }
     model_part = model_state.model[0]
-    target_modules = _extract_target_modules(model_part)
+    target_modules = _extract_target_modules(model_part, v4_compatible=v4_compatible)
+    target_parameters = _extract_target_parameters(model_part, v4_compatible=v4_compatible)
     try:
         arch_name = model_part.config.architectures[0]
         # "LlamaForCausalLM".split("For") → ["Llama", "CausalLM"]
@@ -217,7 +220,7 @@ def _get_hf_peft_config(peft_config: "PeftConfig", model_state: ModelState) -> d
     except KeyError:
         task_type = "CAUSAL_LM"

-    return {
+    config = {
         "task_type": task_type,
         "peft_type": "LORA",
         "r": peft_config.dim,
@@ -227,6 +230,9 @@ def _get_hf_peft_config(peft_config: "PeftConfig", model_state: ModelState) -> d
         "bias": "none",
         "base_model_name_or_path": name_or_path,
     }
+    if target_parameters:
+        config["target_parameters"] = target_parameters
+    return config


 def _get_automodel_peft_metadata(peft_config: "PeftConfig") -> dict:
@@ -244,28 +250,43 @@ def _get_automodel_peft_metadata(peft_config: "PeftConfig") -> dict:
     return {k: v for k, v in peft_config.to_dict().items() if k not in PEFT_KEYS}


-def _extract_target_modules(model: nn.Module) -> list[str]:
+def _is_qwen3_moe(model: nn.Module) -> bool:
+    """Check whether *model* uses the Qwen3 MoE state-dict adapter."""
+    adapter = getattr(model, "state_dict_adapter", None)
+    if adapter is None:
+        return False
+    from nemo_automodel.components.models.qwen3_moe.state_dict_adapter import Qwen3MoeStateDictAdapter
+
+    return isinstance(adapter, Qwen3MoeStateDictAdapter)
+
+
+def _extract_target_parameters(model: nn.Module, v4_compatible: bool = False) -> list[str]:
+    """Extract ``target_parameters`` for PEFT v0.18+ ParamWrapper format.
+
+    Returns fused expert parameter paths for Qwen3 MoE when not in legacy mode,
+    or an empty list otherwise.
     """
-    Extract the target modules from the model used by LoRA/PEFT layers.
+    if v4_compatible:
+        return []
+    if _is_qwen3_moe(model):
+        return ["mlp.experts.gate_up_proj", "mlp.experts.down_proj"]
+    return []

-    Combined-projection module names (e.g. ``qkv_proj``, ``gate_up_proj``) are
-    expanded to the individual Hugging Face projection names so that the saved
-    ``adapter_config.json`` is compatible with vLLM, TensorRT-LLM and the
-    Hugging Face PEFT library.

-    For MoE expert LoRA (GroupedExpertsLoRA / GroupedExpertsDeepEPLoRA), the
-    grouped 3-D adapter parameters are expanded to per-expert HF projection
-    names (e.g. ``model.layers.0.mlp.experts.0.gate_proj``).
+def _extract_target_modules(model: nn.Module, v4_compatible: bool = False) -> list[str]:
+    """
+    Extract the target modules from the model used by LoRA/PEFT layers.

-    Note:
-        When torch.compile is used, module names get prefixed with `_orig_mod.`.
-        This function strips those prefixes to get the original module names.
+    Combined-projection module names (e.g. ``qkv_proj``, ``gate_up_proj``) are
+    expanded to the individual HF projection names for adapter_config.json
+    compatibility with vLLM, TensorRT-LLM, and HF PEFT.

-    Args:
-        model: The model whose named modules are scanned.
+    For MoE expert LoRA, grouped 3-D adapter parameters are expanded to
+    per-expert HF projection names unless the model is Qwen3 MoE in
+    non-legacy mode (where ``target_parameters`` is used instead).

-    Returns:
-        A sorted list of unique module name prefixes that contain LoRA layers.
+    Strips ``_orig_mod.`` (torch.compile) and ``_checkpoint_wrapped_module.``
+    (activation checkpointing) prefixes from module names.
     """
     # Mapping from combined projection names to their HF-compatible split names.
     _COMBINED_TO_SPLIT = {
@@ -278,10 +299,10 @@ def _extract_target_modules(model: nn.Module) -> list[str]:
     final_target_modules = set()
     for name, _ in model.named_modules():
         if "lora" in name.lower():
-            # Remove the torch.compile _orig_mod prefix if present
             target_name = name.rsplit(".", 1)[0]
             if target_name.startswith("_orig_mod."):
                 target_name = target_name[len("_orig_mod.") :]
+            target_name = target_name.replace("_checkpoint_wrapped_module.", "")

             # Expand combined projection names to individual HF projection names
             last_component = target_name.rsplit(".", 1)[-1]
@@ -293,13 +314,14 @@ def _extract_target_modules(model: nn.Module) -> list[str]:
         else:
             final_target_modules.add(target_name)

-    # Detect MoE expert LoRA: adapter weights stored as nn.Parameter (not
-    # nn.Module) so they don't appear in named_modules(). Scan parameters
-    # and expand to per-expert HF projection names.
-    # Only applies to models that use split-expert state dict conversion
-    # (MoESplitExpertsStateDictMixin); models with natively merged experts
-    # (e.g. Qwen 3.5) don't need per-expert expansion.
-    if hasattr(model, "state_dict_adapter") and isinstance(model.state_dict_adapter, MoESplitExpertsStateDictMixin):
+    # MoE expert LoRA: adapter weights are nn.Parameter (not nn.Module) so
+    # they don't appear in named_modules(). Expand to per-expert HF names,
+    # unless Qwen3 MoE in non-legacy mode (uses target_parameters instead).
+    _has_split_expert_mixin = hasattr(model, "state_dict_adapter") and isinstance(
+        model.state_dict_adapter, MoESplitExpertsStateDictMixin
+    )
+    _skip_for_qwen3 = not v4_compatible and _is_qwen3_moe(model)
+    if _has_split_expert_mixin and not _skip_for_qwen3:
         seen_expert_groups: set[tuple[str, str]] = set()
         for name, param in model.named_parameters():
             if not param.requires_grad:
@@ -309,6 +331,7 @@ def _extract_target_modules(model: nn.Module) -> list[str]:
                 expert_path = name[: -len(f".{lora_suffix}")]
                 if expert_path.startswith("_orig_mod."):
                     expert_path = expert_path[len("_orig_mod.") :]
+                expert_path = expert_path.replace("_checkpoint_wrapped_module.", "")

                 group = "gate_and_up" if "gate_and_up" in lora_suffix else "down"
                 if (expert_path, group) in seen_expert_groups:
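The two adapter key layouts this file now switches between can be sketched as follows; both helpers are illustrative reconstructions of the key formats named in the commit message, not code from the repo:

```python
def expand_per_expert_keys(layer: int, num_experts: int) -> list[str]:
    """Legacy layout (v4_compatible=True): one target module per expert
    and projection, e.g. model.layers.0.mlp.experts.3.gate_proj."""
    return [
        f"model.layers.{layer}.mlp.experts.{e}.{proj}"
        for e in range(num_experts)
        for proj in ("gate_proj", "up_proj", "down_proj")
    ]


def fused_target_parameters() -> list[str]:
    """ParamWrapper layout (PEFT v0.18+, the new default): target the fused
    nn.Parameter paths directly via adapter_config.json's target_parameters."""
    return ["mlp.experts.gate_up_proj", "mlp.experts.down_proj"]


legacy = expand_per_expert_keys(layer=0, num_experts=2)
fused = fused_target_parameters()
```

The legacy list grows with the expert count (num_experts × 3 entries per layer), while the fused form stays at two entries and matches the fused nn.Parameter tensors Qwen3 MoE actually holds, which is why vanilla HF can load it.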

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 5 additions & 1 deletion
@@ -290,7 +290,11 @@ def save_model(

         # Convert to HF format if using custom model implementations
         state_dict = _maybe_adapt_state_dict_to_hf(
-            model_state.model[0], state_dict, quantization=False, device_mesh=self.moe_mesh
+            model_state.model[0],
+            state_dict,
+            quantization=False,
+            device_mesh=self.moe_mesh,
+            v4_compatible=self.config.v4_compatible,
         )
         # Build the consolidated model.safetensors.index.json if needed
         fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)

nemo_automodel/components/checkpoint/stateful_wrappers.py

Lines changed: 6 additions & 1 deletion
@@ -96,6 +96,9 @@ def _get_peft_state_dict(model: torch.nn.Module) -> dict[str, Any]:
     state_dict = {}
     for name, param in model.named_parameters():
         if param.requires_grad:
+            # Strip _checkpoint_wrapped_module. from FQNs to match DCP's normalization.
+            # Without this, activation checkpointing causes key mismatches on reload.
+            name = name.replace("_checkpoint_wrapped_module.", "")
             param = param.full_tensor() if hasattr(param, "full_tensor") else param
             state_dict[name] = param.detach().cpu()
     return state_dict
@@ -110,7 +113,9 @@ def _set_peft_state_dict(model: torch.nn.Module, state_dict: dict[str, Any]) ->
     """
     from torch.distributed.tensor import DTensor, Replicate

-    param_dict = dict(model.named_parameters())
+    # Strip _checkpoint_wrapped_module. from FQNs to match DCP's normalization.
+    # Without this, activation checkpointing causes key mismatches on reload.
+    param_dict = {name.replace("_checkpoint_wrapped_module.", ""): param for name, param in model.named_parameters()}
     loaded, skipped = 0, 0

     for name, saved_tensor in state_dict.items():
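A minimal sketch of the FQN normalization both helpers now apply; `normalize_fqn` is a hypothetical stand-in for the inline .replace() calls, shown here to make the save/load key-matching visible:

```python
def normalize_fqn(name: str) -> str:
    """Normalize a parameter FQN the way the checkpoint paths expect.

    Activation checkpointing wraps submodules and injects
    '_checkpoint_wrapped_module.' segments into parameter names;
    torch.compile prepends '_orig_mod.'. Both must be stripped so the
    keys produced at save time match the keys looked up at load time.
    """
    if name.startswith("_orig_mod."):
        name = name[len("_orig_mod."):]
    return name.replace("_checkpoint_wrapped_module.", "")


# A wrapped name as it appears in model.named_parameters() under
# activation checkpointing (illustrative, not from a real model):
wrapped = "model.layers.0._checkpoint_wrapped_module.mlp.experts.lora_a"
saved_key = normalize_fqn(wrapped)
```

If only one side normalizes, the lookup dict and the saved state dict disagree on every wrapped key, and those weights are silently skipped on reload, which is the KL-divergence failure the commit describes.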

nemo_automodel/components/distributed/optimized_tp_plans.py

Lines changed: 23 additions & 0 deletions
@@ -41,6 +41,7 @@
 from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
 from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3ForSequenceClassification

+from nemo_automodel.components.models.baichuan.model import BaichuanForCausalLM
 from nemo_automodel.components.models.llama.model import LlamaForCausalLM as CustomLlamaForCausalLM
 from nemo_automodel.components.models.mistral3.model import Ministral3ForCausalLM
 from nemo_automodel.components.models.qwen2.model import Qwen2ForCausalLM as CustomQwen2ForCausalLM
@@ -268,6 +269,27 @@ def get_decilm_nemotron_tp_plan(
     return cast(dict[str, ParallelStyle], base_model_tp_plan)


+def _parallelize_baichuan(
+    model: BaichuanForCausalLM | None,
+    sequence_parallel: bool = False,
+) -> dict[str, ParallelStyle]:
+    """Parallelizes a BaichuanForCausalLM model (MLP-only).
+
+    Only the MLP is sharded. The attention path stays fully replicated
+    because W_pack uses a non-interleaved [Q|K|V] layout (ColwiseParallel
+    would split it incorrectly) and NormHead (lm_head) is not nn.Linear
+    (ColwiseParallel is unsupported).
+    """
+    return cast(
+        dict[str, ParallelStyle],
+        {
+            "model.layers.*.mlp.gate_proj": ColwiseParallel(),
+            "model.layers.*.mlp.up_proj": ColwiseParallel(),
+            "model.layers.*.mlp.down_proj": RowwiseParallel(),
+        },
+    )
+
+
 def _parallelize_llama(
     model: LlamaForCausalLM | None,
     sequence_parallel: bool = False,
@@ -525,6 +547,7 @@ def _get_class_qualname(cls: type) -> str:

 # Keyed by qualified class name — see _get_class_qualname for why.
 PARALLELIZE_FUNCTIONS: Dict[str, Callable[..., Dict[str, ParallelStyle]]] = {
+    _get_class_qualname(BaichuanForCausalLM): _parallelize_baichuan,
     _get_class_qualname(Qwen2ForCausalLM): _parallelize_qwen,
     _get_class_qualname(Qwen3ForCausalLM): _parallelize_qwen,
     _get_class_qualname(Qwen3ForSequenceClassification): _parallelize_qwen_classification,
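A toy sketch of why contiguous column sharding mis-handles a fused, non-interleaved W_pack, motivating the MLP-only plan above. The row labels and shapes here are illustrative, not the real weights:

```python
def colwise_shard(rows: list[str], tp_size: int) -> list[list[str]]:
    """Split a weight's output rows contiguously across TP ranks, the way
    a column-wise sharding of an nn.Linear divides its output dimension."""
    per_rank = len(rows) // tp_size
    return [rows[r * per_rank:(r + 1) * per_rank] for r in range(tp_size)]


# Toy W_pack with hidden size 2: output rows are [Q|K|V] stacked in three
# contiguous blocks, not interleaved per attention head.
w_pack_rows = ["q0", "q1", "k0", "k1", "v0", "v1"]
shards = colwise_shard(w_pack_rows, tp_size=2)
# Rank 0 ends up with all of Q plus half of K; rank 1 gets the rest of K
# and all of V. Neither rank holds an aligned (Q, K, V) slice, so attention
# computed independently per rank would be wrong.
```

Plans that do shard fused QKV rely on an interleaved layout (or on resharding the weight first); since Baichuan's W_pack has neither, the attention path is left replicated and only gate_proj/up_proj (column-wise) and down_proj (row-wise) are sharded.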
