Skip to content

Commit aa9574c

Browse files
athitten (Abhishree Thittenamane) and root
authored and committed
fix: Allow use_cache when activation_checkpointing is True (NVIDIA-NeMo#1726)
* Allow use_cache when activation_checkpointing is True Signed-off-by: Abhishree Thittenamane <athittenaman@cw-dfw-cs-001-login-02.cm.cluster> * Fix lint errors Signed-off-by: Abhishree <abhishreetm@gmail.com> * Fix imports linting Signed-off-by: Abhishree <abhishreetm@gmail.com> * Add tests Signed-off-by: Abhishree Thittenamane <athittenaman@cw-dfw-cs-001-login-02.cm.cluster> * Add more tests Signed-off-by: root <root@pool0-01595.cm.cluster> --------- Signed-off-by: Abhishree Thittenamane <athittenaman@cw-dfw-cs-001-login-02.cm.cluster> Signed-off-by: Abhishree <abhishreetm@gmail.com> Signed-off-by: root <root@pool0-01595.cm.cluster> Co-authored-by: Abhishree Thittenamane <athittenaman@cw-dfw-cs-001-login-02.cm.cluster> Co-authored-by: root <root@pool0-01595.cm.cluster>
1 parent 896efb2 commit aa9574c

File tree

2 files changed

+353
-8
lines changed

2 files changed

+353
-8
lines changed

nemo_automodel/components/distributed/parallelizer.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,19 @@ def parallelize(
185185

186186
# Apply activation checkpointing to linear layers if requested
187187
if activation_checkpointing:
188-
# Disable KV caching during training to ensure deterministic
189-
# shapes between forward and checkpoint recomputation.
190-
if hasattr(model, "config") and getattr(model.config, "use_cache", None) is not False:
191-
try:
192-
model.config.use_cache = False
193-
except Exception:
194-
pass
188+
# Models with KV-shared layers (e.g. Gemma4 2B/4B) pass K/V from
189+
# earlier layers to later layers through the DynamicCache. Disabling
190+
# the cache breaks this architectural dependency, so we must keep
191+
# use_cache=True for those models.
192+
_text_cfg = getattr(getattr(model, "config", None), "text_config", None) or getattr(model, "config", None)
193+
_has_kv_sharing = getattr(_text_cfg, "num_kv_shared_layers", 0) > 0
194+
195+
if not _has_kv_sharing:
196+
if hasattr(model, "config") and getattr(model.config, "use_cache", None) is not False:
197+
try:
198+
model.config.use_cache = False
199+
except Exception:
200+
pass
195201

196202
# For HF-native models in transformers >= 5.3.0, GradientCheckpointingLayer.__call__
197203
# applies torch.utils.checkpoint at full-layer granularity when gradient_checkpointing=True.
@@ -219,7 +225,10 @@ def parallelize(
219225
for i, layer in enumerate(layers):
220226
if hasattr(layer, "mlp"):
221227
layers[i].mlp = checkpoint_wrapper(layer.mlp)
222-
if hasattr(layer, "self_attn"):
228+
# Skip self_attn checkpointing for KV-shared models:
229+
# recomputation would double-write to the DynamicCache,
230+
# corrupting K/V entries that shared layers depend on.
231+
if hasattr(layer, "self_attn") and not _has_kv_sharing:
223232
layers[i].self_attn = checkpoint_wrapper(layers[i].self_attn) # type: ignore
224233

225234
if hasattr(layer, "input_layernorm"):

tests/unit_tests/distributed/test_parallelizer.py

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,3 +1169,339 @@ def test_no_attn_keys_is_not_sharded(self):
11691169

11701170
def test_empty_plan_is_not_sharded(self):
11711171
assert _attention_is_head_sharded({}) is False
1172+
1173+
1174+
# ---------------------------------------------------------------------------
1175+
# Activation checkpointing + KV-sharing tests
1176+
# ---------------------------------------------------------------------------
1177+
1178+
1179+
class _FakeLayer(nn.Module):
1180+
"""Minimal transformer layer with mlp, self_attn, and layernorms."""
1181+
1182+
def __init__(self, dim: int = 16):
1183+
super().__init__()
1184+
self.mlp = nn.Linear(dim, dim)
1185+
self.self_attn = nn.Linear(dim, dim)
1186+
self.input_layernorm = nn.Linear(dim, dim)
1187+
self.post_attention_layernorm = nn.Linear(dim, dim)
1188+
1189+
def forward(self, x):
1190+
return x
1191+
1192+
1193+
def _make_model_for_ac(
    num_layers: int = 2,
    dim: int = 16,
    use_cache: bool = True,
    num_kv_shared_layers: int = 0,
    text_config_nested: bool = True,
):
    """Build a minimal model with configurable KV-sharing for activation-checkpointing tests.

    Args:
        num_layers: Number of ``_FakeLayer`` blocks placed under ``model.model.layers``.
        dim: Feature width forwarded to each fake layer.
        use_cache: Initial value of ``config.use_cache``.
        num_kv_shared_layers: KV-sharing layer count exposed on the config.
        text_config_nested: If True, place ``num_kv_shared_layers`` under
            ``config.text_config`` (VLM pattern). If False, place it directly
            on ``config`` (flat LLM pattern).
    """
    # Inner "model.model" container holding the layer stack, mirroring the
    # HF layout that _extract_model_layers expects.
    inner = nn.Module()
    inner.layers = nn.ModuleList(_FakeLayer(dim) for _ in range(num_layers))

    model = nn.Module()
    model.model = inner  # type: ignore[attr-defined]

    if text_config_nested:
        # VLM-style: KV-sharing count lives on a nested text_config.
        cfg = SimpleNamespace(
            use_cache=use_cache,
            text_config=SimpleNamespace(num_kv_shared_layers=num_kv_shared_layers),
        )
    else:
        # Flat LLM-style: KV-sharing count sits directly on the config.
        cfg = SimpleNamespace(
            use_cache=use_cache,
            num_kv_shared_layers=num_kv_shared_layers,
        )
    model.config = cfg  # type: ignore[attr-defined]

    model.forward = lambda x: x  # type: ignore[attr-defined]
    return model
1226+
1227+
1228+
class TestActivationCheckpointingKVSharing:
    """Tests for the KV-sharing–aware activation-checkpointing guards
    in ``DefaultParallelizationStrategy.parallelize``.
    """

    @pytest.fixture(autouse=True)
    def _patch_parallelizer(self, monkeypatch):
        """Patch heavy distributed primitives so we can call ``parallelize``
        without a real GPU mesh. ``checkpoint_wrapper`` is replaced with a
        lightweight wrapper that records which module was wrapped.
        """

        class _Wrapped(nn.Module):
            """Sentinel wrapper so we can assert which sub-modules were checkpointed.

            Must inherit from ``nn.Module`` because PyTorch's ``__setattr__``
            rejects non-Module values when replacing a registered child module.
            """

            def __init__(self, inner):
                super().__init__()
                self._inner = inner

            def forward(self, x):
                return self._inner(x)

        # Expose the sentinel class so individual tests can isinstance-check it.
        self._Wrapped = _Wrapped

        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer.checkpoint_wrapper",
            _Wrapped,
        )
        # fully_shard / apply_fsdp2_sharding_recursively / get_submesh would
        # touch real distributed state; stub them to no-ops for these tests.
        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer.fully_shard",
            lambda model, **kw: model,
        )
        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer.apply_fsdp2_sharding_recursively",
            lambda *a, **kw: None,
        )
        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer.get_submesh",
            lambda mesh, names: MagicMock(),
        )

    def _run_parallelize(self, model, activation_checkpointing=True):
        """Invoke the strategy under test and return the model."""
        from nemo_automodel.components.distributed.parallelizer import DefaultParallelizationStrategy

        strategy = DefaultParallelizationStrategy()
        mesh = MagicMock(spec=DeviceMesh)
        tp_mesh = MagicMock()
        tp_mesh.size.return_value = 1  # no TP
        # Any sub-mesh lookup (e.g. mesh["tp"]) returns the size-1 TP mesh.
        mesh.__getitem__ = lambda self_, key: tp_mesh
        return strategy.parallelize(
            model=model,
            device_mesh=mesh,
            activation_checkpointing=activation_checkpointing,
        )

    # ------------------------------------------------------------------ #
    # use_cache preservation / disabling
    # ------------------------------------------------------------------ #

    def test_use_cache_preserved_when_kv_sharing(self):
        """Models with num_kv_shared_layers > 0 must keep use_cache=True."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=20)
        self._run_parallelize(model)
        assert model.config.use_cache is True

    def test_use_cache_disabled_without_kv_sharing(self):
        """Standard models (num_kv_shared_layers=0) get use_cache=False."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=0)
        self._run_parallelize(model)
        assert model.config.use_cache is False

    def test_use_cache_preserved_flat_config(self):
        """KV-sharing detected through a flat config (no text_config nesting)."""
        model = _make_model_for_ac(
            use_cache=True, num_kv_shared_layers=10, text_config_nested=False
        )
        self._run_parallelize(model)
        assert model.config.use_cache is True

    def test_use_cache_disabled_flat_config_no_sharing(self):
        """Flat config without KV sharing still disables cache."""
        model = _make_model_for_ac(
            use_cache=True, num_kv_shared_layers=0, text_config_nested=False
        )
        self._run_parallelize(model)
        assert model.config.use_cache is False

    def test_use_cache_noop_when_already_false(self):
        """If use_cache is already False and no KV sharing, code path is a no-op."""
        model = _make_model_for_ac(use_cache=False, num_kv_shared_layers=0)
        self._run_parallelize(model)
        assert model.config.use_cache is False

    def test_no_config_does_not_crash(self, monkeypatch):
        """Model without a config attribute must not raise."""
        # Layer extraction would fail on a bare nn.Module; stub it out so we
        # exercise only the config-handling path.
        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer._extract_model_layers",
            lambda m: [],
        )
        model = nn.Module()
        model.forward = lambda x: x  # type: ignore[attr-defined]
        # no model.config at all
        self._run_parallelize(model)  # should not raise

    # ------------------------------------------------------------------ #
    # self_attn checkpoint wrapping
    # ------------------------------------------------------------------ #

    def test_self_attn_not_wrapped_when_kv_sharing(self):
        """KV-shared models: self_attn must NOT be wrapped (would corrupt cache)."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=20)
        self._run_parallelize(model)
        for layer in model.model.layers:
            assert not isinstance(layer.self_attn, self._Wrapped), (
                "self_attn should NOT be checkpoint-wrapped for KV-shared models"
            )

    def test_self_attn_wrapped_without_kv_sharing(self):
        """Standard models: self_attn IS wrapped."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=0)
        self._run_parallelize(model)
        for layer in model.model.layers:
            assert isinstance(layer.self_attn, self._Wrapped), (
                "self_attn should be checkpoint-wrapped for standard models"
            )

    def test_mlp_always_wrapped(self):
        """MLP is checkpoint-wrapped regardless of KV sharing."""
        for kv_shared in (0, 20):
            model = _make_model_for_ac(num_kv_shared_layers=kv_shared)
            self._run_parallelize(model)
            for layer in model.model.layers:
                assert isinstance(layer.mlp, self._Wrapped), (
                    f"mlp should always be wrapped (num_kv_shared_layers={kv_shared})"
                )

    def test_layernorms_always_wrapped(self):
        """Layernorms are checkpoint-wrapped regardless of KV sharing."""
        for kv_shared in (0, 20):
            model = _make_model_for_ac(num_kv_shared_layers=kv_shared)
            self._run_parallelize(model)
            for layer in model.model.layers:
                assert isinstance(layer.input_layernorm, self._Wrapped)
                assert isinstance(layer.post_attention_layernorm, self._Wrapped)

    def test_no_wrapping_without_activation_checkpointing(self):
        """When activation_checkpointing=False, nothing is wrapped."""
        model = _make_model_for_ac(num_kv_shared_layers=0)
        self._run_parallelize(model, activation_checkpointing=False)
        for layer in model.model.layers:
            assert not isinstance(layer.mlp, self._Wrapped)
            assert not isinstance(layer.self_attn, self._Wrapped)
        assert model.config.use_cache is True  # untouched

    # ------------------------------------------------------------------ #
    # Exception / edge-case branches
    # ------------------------------------------------------------------ #

    def test_frozen_config_use_cache_except_branch(self):
        """When ``model.config.use_cache = False`` raises, the except branch runs."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=0)

        class _FrozenConfig:
            # Immutable config stand-in: reads succeed, any write raises.
            use_cache = True
            text_config = SimpleNamespace(num_kv_shared_layers=0)

            def __setattr__(self, name, value):
                raise AttributeError("frozen")

        model.config = _FrozenConfig()  # type: ignore[attr-defined]
        self._run_parallelize(model)
        # use_cache stays True because the assignment raised and was caught
        assert model.config.use_cache is True

    def test_no_config_with_layers_does_not_crash(self):
        """Model without ``config`` but with extractable layers does not crash."""

        class _Bare(nn.Module):
            def __init__(self):
                super().__init__()
                self.model = nn.Module()
                self.model.layers = nn.ModuleList([_FakeLayer() for _ in range(2)])  # type: ignore[attr-defined]

            def forward(self, x):
                return x

        model = _Bare()
        # no model.config → hasattr(model, "config") is False
        self._run_parallelize(model)
        # mlp should still be wrapped (activation_checkpointing still applies)
        for layer in model.model.layers:
            assert isinstance(layer.mlp, self._Wrapped)

    def test_layer_missing_self_attn(self):
        """Layers without ``self_attn`` are skipped gracefully."""

        class _MlpOnlyLayer(nn.Module):
            def __init__(self):
                super().__init__()
                self.mlp = nn.Linear(16, 16)

            def forward(self, x):
                return x

        model = _make_model_for_ac(num_kv_shared_layers=0)
        model.model.layers = nn.ModuleList([_MlpOnlyLayer() for _ in range(2)])
        self._run_parallelize(model)
        for layer in model.model.layers:
            assert isinstance(layer.mlp, self._Wrapped)
            assert not hasattr(layer, "self_attn")

    def test_layer_missing_mlp(self):
        """Layers without ``mlp`` are skipped gracefully."""

        class _AttnOnlyLayer(nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = nn.Linear(16, 16)

            def forward(self, x):
                return x

        model = _make_model_for_ac(num_kv_shared_layers=0)
        model.model.layers = nn.ModuleList([_AttnOnlyLayer() for _ in range(2)])
        self._run_parallelize(model)
        for layer in model.model.layers:
            assert isinstance(layer.self_attn, self._Wrapped)
            assert not hasattr(layer, "mlp")

    # ------------------------------------------------------------------ #
    # HF native gradient-checkpointing path
    # ------------------------------------------------------------------ #

    @staticmethod
    def _setup_hf_native_model(monkeypatch, num_kv_shared_layers):
        """Helper: configure a model + fake transformers module for the HF native path."""
        import types

        class _FakeGradLayer(_FakeLayer):
            pass

        # The parallelizer identifies HF-native layers by their module path,
        # so pretend this class lives inside transformers' modeling code.
        _FakeGradLayer.__module__ = "transformers.models.gemma4.modeling_gemma4"

        # Install a fake transformers.modeling_layers so the import inside
        # parallelize() resolves without transformers being importable here.
        fake_module = types.ModuleType("transformers.modeling_layers")
        fake_module.GradientCheckpointingLayer = _FakeGradLayer  # type: ignore[attr-defined]
        monkeypatch.setitem(sys.modules, "transformers.modeling_layers", fake_module)

        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=num_kv_shared_layers)
        for i in range(len(model.model.layers)):
            model.model.layers[i] = _FakeGradLayer()
        model.supports_gradient_checkpointing = True  # type: ignore[attr-defined]
        model.gradient_checkpointing_enable = MagicMock()  # type: ignore[attr-defined]
        return model

    def test_hf_native_grad_ckpt_preserves_use_cache_with_kv_sharing(self, monkeypatch):
        """Even when the HF native path is taken, use_cache stays True for KV-shared models."""
        model = self._setup_hf_native_model(monkeypatch, num_kv_shared_layers=20)
        self._run_parallelize(model)

        assert model.config.use_cache is True
        model.gradient_checkpointing_enable.assert_called_once()

    def test_hf_native_grad_ckpt_disables_use_cache_without_kv_sharing(self, monkeypatch):
        """HF native path + no KV sharing: use_cache is set to False."""
        model = self._setup_hf_native_model(monkeypatch, num_kv_shared_layers=0)
        self._run_parallelize(model)

        assert model.config.use_cache is False
        model.gradient_checkpointing_enable.assert_called_once_with(
            gradient_checkpointing_kwargs={"use_reentrant": True}
        )

0 commit comments

Comments
 (0)