Correctness issue when Gemma4 E2B/E4B models (KV-sharing models) training has activation_checkpointing enabled #1705

@sharonyu-115

Description

Describe the bug

Discovered this when working on Gemma4 RL support: NVIDIA-NeMo/RL#2212
Automodel's parallelizer (nemo_automodel/components/distributed/parallelizer.py) explicitly forces use_cache=False whenever activation checkpointing is enabled:

if activation_checkpointing:
    # Disable KV caching during training to ensure deterministic
    # shapes between forward and checkpoint recomputation.
    if hasattr(model, "config") and getattr(model.config, "use_cache", None) is not False:
        try:
            model.config.use_cache = False
        except Exception:
            pass

This safeguard correctly prevents DynamicCache shape mismatches during gradient recomputation (DynamicCache is stateful, so recomputation doubles the cache and triggers CheckpointError: Recomputed values have different metadata in NeMo-RL's pipeline). However, for KV-sharing models the override silently breaks correctness: shared layers can no longer retrieve K/V from their anchor layers via the DynamicCache, so they fall back to their own untrained k_proj/v_proj weights and produce garbage attention outputs.
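
For illustration, here is a minimal, hypothetical sketch (toy code, not the actual Hugging Face Gemma implementation; all class and argument names are made up) of why a KV-sharing layer depends on the cache object being present:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySharedKVAttention(nn.Module):
    """Toy single-head attention layer that shares K/V with an anchor layer."""

    def __init__(self, hidden_size: int, anchor_layer_idx: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        # In a KV-sharing layer these projections are never exercised on the
        # intended path, so their weights stay effectively untrained.
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.anchor_layer_idx = anchor_layer_idx

    def forward(self, x, kv_cache=None):
        q = self.q_proj(x)
        if kv_cache is not None and self.anchor_layer_idx in kv_cache:
            # Intended path: reuse the anchor layer's K/V from the cache.
            k, v = kv_cache[self.anchor_layer_idx]
        else:
            # With use_cache=False no cache is passed, so the layer silently
            # falls back to its own untrained projections.
            k, v = self.k_proj(x), self.v_proj(x)
        # Add a singleton head dimension for scaled_dot_product_attention.
        out = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
        return out.squeeze(1)

Both branches produce tensors of the same shape, so nothing crashes; the fallback branch just computes attention with weights that carry no signal, which matches the loss degradation shown below.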

Evidence: Automodel SFT on Gemma4 E2B-it with identical configs except for activation checkpointing:

| Config | Step 0 loss | Step 1 loss | Step 2 loss |
| --- | --- | --- | --- |
| use_cache=true, activation_checkpointing=false | 8.67 | 4.72 | 2.34 |
| use_cache=true, activation_checkpointing=true | 16.08 | 10.70 | 8.19 |

The roughly 2x higher loss with activation checkpointing confirms that use_cache was silently forced to False, breaking KV sharing. The model trains without errors, but on corrupted forward-pass outputs.

Steps/Code to reproduce bug

Run Automodel SFT twice with otherwise identical configs, once with activation_checkpointing set to false and once with it set to true, then compare the losses of the two runs.

  • Reproduction logs:
    • Automodel SFT without activation checkpointing: /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/src/NeMo-RL/nemo-rl/slurm-10810588.out
    • Automodel SFT with activation checkpointing: /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/src/NeMo-RL/nemo-rl/slurm-10811067.out

Expected behavior

KV-sharing models currently cannot use activation checkpointing: the parallelizer's safeguard prevents crashes but silently breaks model correctness. Would it be possible to fix this and support activation checkpointing for these models, for example by keeping use_cache=True and resetting the DynamicCache before recomputation?
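
One possible starting point, sketched below purely for discussion (untested; the wrapper name and layer call signature are made up, and it assumes a DynamicCache-like object exposing per-layer key_cache / value_cache tensor lists as in recent transformers releases): truncate the layer's own cache entry back to its pre-forward length at the start of every execution of the checkpointed function, so that backward-time recomputation re-appends into the same slot instead of doubling it, while the anchor layers' entries remain available for KV-sharing reads.

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_layer_forward(layer, layer_idx, hidden_states, past_key_values, **layer_kwargs):
    # Length of this layer's cache entry before the original forward appended to it
    # (0 during training, where the forward starts from an empty cache).
    if past_key_values is not None and layer_idx < len(past_key_values.key_cache):
        start_len = past_key_values.key_cache[layer_idx].shape[-2]
    else:
        start_len = 0

    def run(hs):
        if past_key_values is not None and layer_idx < len(past_key_values.key_cache):
            # Reset only this layer's K/V to their pre-forward length so the
            # recomputed append produces tensors with the original shapes.
            past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][..., :start_len, :]
            past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][..., :start_len, :]
        return layer(hs, past_key_values=past_key_values, use_cache=True, **layer_kwargs)

    return checkpoint(run, hidden_states, use_reentrant=False)

Because backward recomputation runs layers in reverse order, a shared layer's recompute still finds its anchor's K/V intact; the anchor's own entry is only reset when the anchor itself is recomputed, after all layers that read from it.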
