    def test_empty_plan_is_not_sharded(self):
        assert _attention_is_head_sharded({}) is False
1172+
1173+
1174+ # ---------------------------------------------------------------------------
1175+ # Activation checkpointing + KV-sharing tests
1176+ # ---------------------------------------------------------------------------
1177+
1178+
class _FakeLayer(nn.Module):
    """Minimal transformer layer with mlp, self_attn, and layernorms."""

    # Sub-module names the activation-checkpointing code under test inspects.
    _SUBMODULE_NAMES = (
        "mlp",
        "self_attn",
        "input_layernorm",
        "post_attention_layernorm",
    )

    def __init__(self, dim: int = 16):
        super().__init__()
        # Each named slot is a plain Linear; the tests only care about module
        # structure, not about realistic attention/MLP math.
        for name in self._SUBMODULE_NAMES:
            setattr(self, name, nn.Linear(dim, dim))

    def forward(self, x):
        # Identity forward — the layer is a structural stand-in only.
        return x
1191+
1192+
def _make_model_for_ac(
    num_layers: int = 2,
    dim: int = 16,
    use_cache: bool = True,
    num_kv_shared_layers: int = 0,
    text_config_nested: bool = True,
):
    """Build a minimal model with configurable KV-sharing for activation-checkpointing tests.

    Args:
        text_config_nested: If True, place ``num_kv_shared_layers`` under
            ``config.text_config`` (VLM pattern). If False, place it directly
            on ``config`` (flat LLM pattern).
    """

    class _Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList(_FakeLayer(dim) for _ in range(num_layers))

    # Build the config first so the nested/flat layouts read side by side.
    if text_config_nested:
        config = SimpleNamespace(
            use_cache=use_cache,
            text_config=SimpleNamespace(num_kv_shared_layers=num_kv_shared_layers),
        )
    else:
        config = SimpleNamespace(
            use_cache=use_cache,
            num_kv_shared_layers=num_kv_shared_layers,
        )

    model = nn.Module()
    model.model = _Inner()  # type: ignore[attr-defined]
    model.config = config  # type: ignore[attr-defined]
    model.forward = lambda x: x  # type: ignore[attr-defined]
    return model
1226+
1227+
class TestActivationCheckpointingKVSharing:
    """Tests for the KV-sharing–aware activation-checkpointing guards
    in ``DefaultParallelizationStrategy.parallelize``.
    """

    @pytest.fixture(autouse=True)
    def _patch_parallelizer(self, monkeypatch):
        """Patch heavy distributed primitives so we can call ``parallelize``
        without a real GPU mesh. ``checkpoint_wrapper`` is replaced with a
        lightweight wrapper that records which module was wrapped.
        """

        class _Wrapped(nn.Module):
            """Sentinel wrapper so we can assert which sub-modules were checkpointed.

            Must inherit from ``nn.Module`` because PyTorch's ``__setattr__``
            rejects non-Module values when replacing a registered child module.
            """

            def __init__(self, inner):
                super().__init__()
                self._inner = inner

            def forward(self, x):
                return self._inner(x)

        self._Wrapped = _Wrapped

        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer.checkpoint_wrapper",
            _Wrapped,
        )
        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer.fully_shard",
            lambda model, **kw: model,
        )
        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer.apply_fsdp2_sharding_recursively",
            lambda *a, **kw: None,
        )
        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer.get_submesh",
            lambda mesh, names: MagicMock(),
        )

    def _run_parallelize(self, model, activation_checkpointing=True):
        """Invoke the strategy under test and return the model."""
        from nemo_automodel.components.distributed.parallelizer import DefaultParallelizationStrategy

        strategy = DefaultParallelizationStrategy()
        mesh = MagicMock(spec=DeviceMesh)
        tp_mesh = MagicMock()
        tp_mesh.size.return_value = 1  # no TP
        mesh.__getitem__ = lambda self_, key: tp_mesh
        return strategy.parallelize(
            model=model,
            device_mesh=mesh,
            activation_checkpointing=activation_checkpointing,
        )

    # ------------------------------------------------------------------ #
    # use_cache preservation / disabling
    # ------------------------------------------------------------------ #

    def test_use_cache_preserved_when_kv_sharing(self):
        """Models with num_kv_shared_layers > 0 must keep use_cache=True."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=20)
        self._run_parallelize(model)
        assert model.config.use_cache is True

    def test_use_cache_disabled_without_kv_sharing(self):
        """Standard models (num_kv_shared_layers=0) get use_cache=False."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=0)
        self._run_parallelize(model)
        assert model.config.use_cache is False

    def test_use_cache_preserved_flat_config(self):
        """KV-sharing detected through a flat config (no text_config nesting)."""
        model = _make_model_for_ac(
            use_cache=True, num_kv_shared_layers=10, text_config_nested=False
        )
        self._run_parallelize(model)
        assert model.config.use_cache is True

    def test_use_cache_disabled_flat_config_no_sharing(self):
        """Flat config without KV sharing still disables cache."""
        model = _make_model_for_ac(
            use_cache=True, num_kv_shared_layers=0, text_config_nested=False
        )
        self._run_parallelize(model)
        assert model.config.use_cache is False

    def test_use_cache_noop_when_already_false(self):
        """If use_cache is already False and no KV sharing, code path is a no-op."""
        model = _make_model_for_ac(use_cache=False, num_kv_shared_layers=0)
        self._run_parallelize(model)
        assert model.config.use_cache is False

    def test_no_config_does_not_crash(self, monkeypatch):
        """Model without a config attribute must not raise."""
        monkeypatch.setattr(
            "nemo_automodel.components.distributed.parallelizer._extract_model_layers",
            lambda m: [],
        )
        model = nn.Module()
        model.forward = lambda x: x  # type: ignore[attr-defined]
        # no model.config at all
        self._run_parallelize(model)  # should not raise

    # ------------------------------------------------------------------ #
    # self_attn checkpoint wrapping
    # ------------------------------------------------------------------ #

    def test_self_attn_not_wrapped_when_kv_sharing(self):
        """KV-shared models: self_attn must NOT be wrapped (would corrupt cache)."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=20)
        self._run_parallelize(model)
        for layer in model.model.layers:
            assert not isinstance(layer.self_attn, self._Wrapped), (
                "self_attn should NOT be checkpoint-wrapped for KV-shared models"
            )

    def test_self_attn_wrapped_without_kv_sharing(self):
        """Standard models: self_attn IS wrapped."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=0)
        self._run_parallelize(model)
        for layer in model.model.layers:
            assert isinstance(layer.self_attn, self._Wrapped), (
                "self_attn should be checkpoint-wrapped for standard models"
            )

    def test_mlp_always_wrapped(self):
        """MLP is checkpoint-wrapped regardless of KV sharing."""
        for kv_shared in (0, 20):
            model = _make_model_for_ac(num_kv_shared_layers=kv_shared)
            self._run_parallelize(model)
            for layer in model.model.layers:
                assert isinstance(layer.mlp, self._Wrapped), (
                    f"mlp should always be wrapped (num_kv_shared_layers={kv_shared})"
                )

    def test_layernorms_always_wrapped(self):
        """Layernorms are checkpoint-wrapped regardless of KV sharing."""
        for kv_shared in (0, 20):
            model = _make_model_for_ac(num_kv_shared_layers=kv_shared)
            self._run_parallelize(model)
            for layer in model.model.layers:
                assert isinstance(layer.input_layernorm, self._Wrapped)
                assert isinstance(layer.post_attention_layernorm, self._Wrapped)

    def test_no_wrapping_without_activation_checkpointing(self):
        """When activation_checkpointing=False, nothing is wrapped."""
        model = _make_model_for_ac(num_kv_shared_layers=0)
        self._run_parallelize(model, activation_checkpointing=False)
        for layer in model.model.layers:
            assert not isinstance(layer.mlp, self._Wrapped)
            assert not isinstance(layer.self_attn, self._Wrapped)
        assert model.config.use_cache is True  # untouched

    # ------------------------------------------------------------------ #
    # Exception / edge-case branches
    # ------------------------------------------------------------------ #

    def test_frozen_config_use_cache_except_branch(self):
        """When ``model.config.use_cache = False`` raises, the except branch runs."""
        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=0)

        class _FrozenConfig:
            use_cache = True
            text_config = SimpleNamespace(num_kv_shared_layers=0)

            def __setattr__(self, name, value):
                raise AttributeError("frozen")

        model.config = _FrozenConfig()  # type: ignore[attr-defined]
        self._run_parallelize(model)
        # use_cache stays True because the assignment raised and was caught
        assert model.config.use_cache is True

    def test_no_config_with_layers_does_not_crash(self):
        """Model without ``config`` but with extractable layers does not crash."""

        class _Bare(nn.Module):
            def __init__(self):
                super().__init__()
                self.model = nn.Module()
                self.model.layers = nn.ModuleList([_FakeLayer() for _ in range(2)])  # type: ignore[attr-defined]

            def forward(self, x):
                return x

        model = _Bare()
        # no model.config → hasattr(model, "config") is False
        self._run_parallelize(model)
        # mlp should still be wrapped (activation_checkpointing still applies)
        for layer in model.model.layers:
            assert isinstance(layer.mlp, self._Wrapped)

    def test_layer_missing_self_attn(self):
        """Layers without ``self_attn`` are skipped gracefully."""

        class _MlpOnlyLayer(nn.Module):
            def __init__(self):
                super().__init__()
                self.mlp = nn.Linear(16, 16)

            def forward(self, x):
                return x

        model = _make_model_for_ac(num_kv_shared_layers=0)
        model.model.layers = nn.ModuleList([_MlpOnlyLayer() for _ in range(2)])
        self._run_parallelize(model)
        for layer in model.model.layers:
            assert isinstance(layer.mlp, self._Wrapped)
            assert not hasattr(layer, "self_attn")

    def test_layer_missing_mlp(self):
        """Layers without ``mlp`` are skipped gracefully."""

        class _AttnOnlyLayer(nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = nn.Linear(16, 16)

            def forward(self, x):
                return x

        model = _make_model_for_ac(num_kv_shared_layers=0)
        model.model.layers = nn.ModuleList([_AttnOnlyLayer() for _ in range(2)])
        self._run_parallelize(model)
        for layer in model.model.layers:
            assert isinstance(layer.self_attn, self._Wrapped)
            assert not hasattr(layer, "mlp")

    # ------------------------------------------------------------------ #
    # HF native gradient-checkpointing path
    # ------------------------------------------------------------------ #

    @staticmethod
    def _setup_hf_native_model(monkeypatch, num_kv_shared_layers):
        """Helper: configure a model + fake transformers module for the HF native path."""
        import types

        class _FakeGradLayer(_FakeLayer):
            pass

        # The parallelizer detects HF layers by module path, so fake one.
        _FakeGradLayer.__module__ = "transformers.models.gemma4.modeling_gemma4"

        fake_module = types.ModuleType("transformers.modeling_layers")
        fake_module.GradientCheckpointingLayer = _FakeGradLayer  # type: ignore[attr-defined]
        monkeypatch.setitem(sys.modules, "transformers.modeling_layers", fake_module)

        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=num_kv_shared_layers)
        for i in range(len(model.model.layers)):
            model.model.layers[i] = _FakeGradLayer()
        model.supports_gradient_checkpointing = True  # type: ignore[attr-defined]
        model.gradient_checkpointing_enable = MagicMock()  # type: ignore[attr-defined]
        return model

    def test_hf_native_grad_ckpt_preserves_use_cache_with_kv_sharing(self, monkeypatch):
        """Even when the HF native path is taken, use_cache stays True for KV-shared models."""
        model = self._setup_hf_native_model(monkeypatch, num_kv_shared_layers=20)
        self._run_parallelize(model)

        assert model.config.use_cache is True
        model.gradient_checkpointing_enable.assert_called_once()

    def test_hf_native_grad_ckpt_disables_use_cache_without_kv_sharing(self, monkeypatch):
        """HF native path + no KV sharing: use_cache is set to False."""
        model = self._setup_hf_native_model(monkeypatch, num_kv_shared_layers=0)
        self._run_parallelize(model)

        assert model.config.use_cache is False
        model.gradient_checkpointing_enable.assert_called_once_with(
            gradient_checkpointing_kwargs={"use_reentrant": True}
        )