Describe the bug
Discovered this when working on Gemma4 RL support: NVIDIA-NeMo/RL#2212
Automodel's parallelizer (`nemo_automodel/components/distributed/parallelizer.py`) explicitly forces `use_cache=False` when activation checkpointing is enabled:

```python
if activation_checkpointing:
    # Disable KV caching during training to ensure deterministic
    # shapes between forward and checkpoint recomputation.
    if hasattr(model, "config") and getattr(model.config, "use_cache", None) is not False:
        try:
            model.config.use_cache = False
        except Exception:
            pass
```
This safeguard correctly prevents `DynamicCache` shape mismatches during gradient recomputation (`DynamicCache` is stateful, so recomputation appends to it a second time and doubles the cache, causing `CheckpointError: Recomputed values have different metadata` in NeMo-RL's pipeline). However, for KV-sharing models this override silently breaks correctness: shared layers can no longer retrieve K/V from anchor layers via `DynamicCache`, and instead fall back to their own untrained `k_proj`/`v_proj` weights, producing garbage attention outputs.
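A minimal toy sketch of why a stateful cache breaks checkpoint recomputation (the `ToyDynamicCache` class below is purely illustrative, not the actual `transformers` implementation): the recomputed forward appends to the cache again, so the "shapes" seen by attention differ between the original pass and the recompute.

```python
class ToyDynamicCache:
    """Stands in for a stateful KV cache: append-only K/V storage."""

    def __init__(self):
        self.keys = []  # one entry appended per forward call through a layer

    def update(self, k):
        self.keys.append(k)
        return self.keys  # the layer attends over everything cached so far


def layer_forward(cache, k):
    # Return the cache length, i.e. the "shape" the attention op would see.
    return len(cache.update(k))


cache = ToyDynamicCache()
seen_original = layer_forward(cache, "k0")   # original forward: cache length 1
seen_recompute = layer_forward(cache, "k0")  # checkpoint recompute: length 2
print(seen_original, seen_recompute)         # 1 2 -> mismatched metadata
```

The mismatch between `seen_original` and `seen_recompute` is the toy analogue of the `CheckpointError` above.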
Evidence — Automodel SFT on Gemma4 E2B-it with identical config except activation checkpointing:

| Config | Step 0 loss | Step 1 loss | Step 2 loss |
|---|---|---|---|
| `use_cache=true`, `activation_checkpointing=false` | 8.67 | 4.72 | 2.34 |
| `use_cache=true`, `activation_checkpointing=true` | 16.08 | 10.70 | 8.19 |
The ~2x higher loss with activation checkpointing confirms that `use_cache` was silently forced to `False`, breaking KV sharing. The model trains without errors, but on corrupted forward-pass outputs.
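The failure mode can be sketched with a hypothetical KV-sharing layer (names are illustrative, not the actual Gemma code): with a cache present, the shared layer reuses the anchor layer's trained K/V; with `use_cache=False` it silently falls back to its own untrained projections, with no error raised.

```python
# Illustrative stand-ins: what the anchor layer's trained k_proj/v_proj
# would produce vs. the shared layer's own untrained projections.
ANCHOR_KV = "trained_kv"
UNTRAINED_KV = "garbage_kv"


def shared_layer_forward(cache):
    """Toy KV-sharing layer: prefer the anchor layer's cached K/V."""
    if cache is not None and "anchor" in cache:
        return cache["anchor"]  # correct path: reuse anchor K/V
    return UNTRAINED_KV         # silent fallback: wrong activations, no error


with_cache = shared_layer_forward({"anchor": ANCHOR_KV})
without_cache = shared_layer_forward(None)  # the use_cache=False path
print(with_cache, without_cache)            # trained_kv garbage_kv
```

The `use_cache=False` path returns a value of the right type and shape, which is why training proceeds without errors while the loss roughly doubles.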
Steps/Code to reproduce bug
Run Automodel SFT twice, with `activation_checkpointing` set to `false` and `true` respectively, and compare the losses of the two runs.
- Reproduction logs:
- Automodel SFT without activation checkpointing:
/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/src/NeMo-RL/nemo-rl/slurm-10810588.out
- Automodel SFT activation checkpointing test:
/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/src/NeMo-RL/nemo-rl/slurm-10811067.out
Expected behavior
KV-sharing models currently cannot use activation checkpointing — the parallelizer's safeguard prevents crashes but silently breaks model correctness. Would it be possible to fix this and support activation checkpointing for KV-sharing models, e.g. by resetting the `DynamicCache` before recomputation?
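A sketch of the proposed direction, under stated assumptions (the `ToyCache` class and `checkpointed_forward` wrapper are hypothetical, not NeMo/Automodel APIs): record the cache length before the checkpointed forward, then truncate the cache back to that length before recomputation, so both passes see identical shapes while KV sharing keeps working.

```python
class ToyCache:
    """Toy stand-in for a stateful KV cache with a crop()-style truncation."""

    def __init__(self):
        self.keys = []

    def __len__(self):
        return len(self.keys)

    def update(self, k):
        self.keys.append(k)
        return self.keys

    def crop(self, length):
        del self.keys[length:]  # drop entries appended by the first pass


def checkpointed_forward(cache, forward):
    """Run forward twice (mimicking checkpoint recompute) with a cache reset."""
    start = len(cache)
    out = forward(cache)        # original forward (activations discarded)
    cache.crop(start)           # reset cache state before recomputation
    recomputed = forward(cache)  # recompute now sees the same cache state
    assert out == recomputed     # shapes match; no CheckpointError analogue
    return out


cache = ToyCache()
result = checkpointed_forward(cache, lambda c: len(c.update("k0")))
print(result)  # 1
```

Without the `crop(start)` call, the second pass would see a longer cache than the first, reproducing the metadata mismatch described above.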