Skip to content

Commit a55686b

Browse files
committed
fix qwen3.5 pp moduledict layer extraction
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent 3d437a9 commit a55686b

3 files changed

Lines changed: 83 additions & 29 deletions

File tree

nemo_automodel/components/distributed/parallelizer.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,46 +1195,55 @@ def validate_tp_mesh(model, tp_mesh):
11951195
)
11961196

11971197

1198-
def _find_largest_module_list(model: nn.Module) -> Optional[nn.ModuleList]:
1198+
def _find_largest_module_list(model: nn.Module) -> Optional[Union[nn.ModuleList, nn.ModuleDict]]:
11991199
"""
1200-
Heuristic function to find the largest nn.ModuleList in a model.
1200+
Heuristic function to find the largest layer container in a model.
12011201
1202-
This function recursively traverses the model to find all nn.ModuleList instances
1203-
and returns the one with the most modules. This is useful as a fallback when
1204-
the model architecture is unknown, since transformer layers are typically
1205-
organized in ModuleLists.
1202+
This function recursively traverses the model to find all nn.ModuleList and
1203+
pipeline-split nn.ModuleDict instances and returns the one with the most
1204+
modules. This is useful as a fallback when the model architecture is unknown,
1205+
since transformer layers are typically organized in ModuleLists. Pipeline
1206+
splitting converts ModuleLists to ModuleDicts keyed by original layer index.
12061207
12071208
Args:
12081209
model (nn.Module): The model to search through.
12091210
12101211
Returns:
1211-
Optional[nn.ModuleList]: The largest ModuleList found, or None if no ModuleList exists.
1212+
Optional[Union[nn.ModuleList, nn.ModuleDict]]: The largest layer container found, or None.
12121213
"""
1213-
largest_module_list = None
1214+
largest_module_list: Optional[Union[nn.ModuleList, nn.ModuleDict]] = None
12141215
largest_size = 0
12151216

1217+
def _is_pp_layer_module_dict(module: nn.ModuleDict) -> bool:
1218+
# functional.py converts split ModuleLists to ModuleDicts with stringified
1219+
# numeric indices. Avoid treating arbitrary named ModuleDicts (for example
1220+
# adapter registries) as transformer layer containers in the heuristic path.
1221+
return all(key.isdigit() for key in module.keys())
1222+
12161223
def _recursive_search(module: nn.Module, path: str = ""):
12171224
nonlocal largest_module_list, largest_size
12181225

12191226
for name, child in module.named_children():
12201227
current_path = f"{path}.{name}" if path else name
12211228

1222-
if isinstance(child, nn.ModuleList):
1229+
if isinstance(child, nn.ModuleList) or (
1230+
isinstance(child, nn.ModuleDict) and _is_pp_layer_module_dict(child)
1231+
):
12231232
current_size = len(child)
12241233
if current_size > largest_size:
12251234
largest_size = current_size
12261235
largest_module_list = child
1227-
logger.debug(f"Found ModuleList at {current_path} with {current_size} modules")
1236+
logger.debug(f"Found {type(child).__name__} at {current_path} with {current_size} modules")
12281237

12291238
# Continue recursive search
12301239
_recursive_search(child, current_path)
12311240

12321241
_recursive_search(model)
12331242

12341243
if largest_module_list is not None:
1235-
logger.info(f"Largest ModuleList found with {largest_size} modules")
1244+
logger.info(f"Largest layer container found with {largest_size} modules")
12361245
else:
1237-
logger.warning("No ModuleList found in the model")
1246+
logger.warning("No ModuleList or ModuleDict found in the model")
12381247

12391248
return largest_module_list
12401249

@@ -1320,6 +1329,8 @@ def _extend_layers(layers, modules):
13201329
for m in modules:
13211330
if isinstance(m, nn.ModuleList):
13221331
layers.extend(m)
1332+
elif isinstance(m, nn.ModuleDict):
1333+
layers.extend(m.values())
13231334
else:
13241335
layers.append(m)
13251336

@@ -1338,15 +1349,20 @@ def _extend_layers(layers, modules):
13381349
elif hasattr(model, "layers"):
13391350
layers.extend(model.layers)
13401351
else:
1341-
# Use heuristic to find the largest ModuleList in the model
1352+
# Use heuristic to find the largest layer container in the model.
13421353
logger.warning(f"Unknown model type: {model_cls}. Using heuristic to find transformer layers.")
13431354
largest_module_list = _find_largest_module_list(model)
13441355
if largest_module_list is None:
1345-
# If no ModuleList found, still raise an exception
1356+
# If no layer container is found, still raise an exception.
13461357
print(model)
1347-
raise ValueError(f"Unknown model type: {model_cls} and no ModuleList found in model structure")
1358+
raise ValueError(
1359+
f"Unknown model type: {model_cls} and no ModuleList or ModuleDict found in model structure"
1360+
)
13481361

1349-
layers.extend(largest_module_list)
1362+
if isinstance(largest_module_list, nn.ModuleDict):
1363+
layers.extend(largest_module_list.values())
1364+
else:
1365+
layers.extend(largest_module_list)
13501366
logger.info(f"Successfully extracted {len(largest_module_list)} layers using heuristic")
13511367

13521368
assert all(isinstance(m, nn.Module) for m in layers), "layers shoudl be nn.Module instances"

tests/unit_tests/distributed/test_parallelizer.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,8 +1525,9 @@ class TestExtractModelLayers:
15251525
15261526
Covers the PR that replaced ``layers.extend(_reduce_attrs(...))`` with a
15271527
helper that flattens ModuleList elements so each decoder layer ends up as
1528-
its own list entry (what AC wrapping expects), while leaving non-ModuleList
1529-
results (e.g. ModuleDict after PP split) appended as-is.
1528+
its own list entry (what AC wrapping expects). PP splitting represents kept
1529+
layer subsets as ModuleDicts, and those layer containers should be flattened
1530+
the same way.
15301531
"""
15311532

15321533
def _make_layers(self, n: int) -> nn.ModuleList:
@@ -1616,14 +1617,13 @@ def test_multi_fqn_flattens_each_modulelist(self):
16161617
assert [id(r) for r in result[5:]] == [id(item) for item in vis]
16171618
assert not any(isinstance(r, nn.ModuleList) for r in result)
16181619

1619-
def test_non_modulelist_element_appended_as_single_entry(self):
1620+
def test_moduledict_layer_container_flattens(self):
16201621
"""PP post-split: ``_reduce_attrs`` returns a ModuleDict.
16211622
1622-
A ModuleDict is NOT an nn.ModuleList, so ``_extend_layers`` must fall
1623-
through to ``layers.append(m)`` and keep it as a single element —
1624-
same behaviour as before the fix (the AC loop then skips it via
1625-
hasattr, which is the expected PP-path behaviour and handled
1626-
elsewhere for the happy PP case).
1623+
The pipeline splitter replaces a ModuleList with a numeric-key
1624+
ModuleDict. ``_extract_model_layers`` must still return individual
1625+
layers so AC, TP follow-up logic, and FSDP layer handling see the same
1626+
shape as the unsplit path.
16271627
"""
16281628
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
16291629

@@ -1635,9 +1635,8 @@ def test_non_modulelist_element_appended_as_single_entry(self):
16351635

16361636
result = _extract_model_layers(model)
16371637

1638-
# ModuleDict is not flattened — it stays as one element.
1639-
assert len(result) == 1
1640-
assert result[0] is layer_dict
1638+
assert len(result) == 2
1639+
assert [id(r) for r in result] == [id(v) for v in layer_dict.values()]
16411640

16421641
def test_fallback_branch_still_handles_modulelist(self):
16431642
"""Non-MODEL_CLS_TO_LAYERS models hit the ``hasattr(model.model, 'layers')``
@@ -1671,6 +1670,17 @@ def __init__(self, layer_dict):
16711670
assert len(result) == 3
16721671
assert [id(r) for r in result] == [id(v) for v in layer_dict.values()]
16731672

1673+
def test_heuristic_ignores_named_moduledict(self):
1674+
"""The unknown-model heuristic should not treat arbitrary ModuleDicts as layers."""
1675+
1676+
class UnknownWithAdapterRegistry(nn.Module):
1677+
def __init__(self):
1678+
super().__init__()
1679+
self.adapters = nn.ModuleDict({"default": nn.Linear(4, 4)})
1680+
1681+
with pytest.raises(ValueError, match="no ModuleList or ModuleDict found"):
1682+
_extract_model_layers(UnknownWithAdapterRegistry())
1683+
16741684
def test_string_keyed_mistral3_fp8_vlm(self):
16751685
"""The ``"Mistral3FP8VLMForConditionalGeneration"`` string-key entry
16761686
catches the runtime class produced by ``_get_mixin_wrapped_class``

tests/unit_tests/distributed/test_qwen3_5_tp_and_grad_sync.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class TestExtractModelLayersStringFallbackAndNoneSafe:
110110
(which happen after PP stage split strips unused sub-modules).
111111
"""
112112

113-
def _make_fake_qwen35(self, visual_is_none: bool):
113+
def _make_fake_qwen35(self, visual_is_none: bool, layers_as_module_dict: bool = False):
114114
"""Build a stand-in object whose type().__name__ is
115115
'Qwen3_5ForConditionalGeneration' but is NOT the real class — this
116116
mimics the lazy-import / deepcopy class-identity drift case."""
@@ -121,7 +121,10 @@ class Qwen3_5ForConditionalGeneration(nn.Module): # noqa: N801 (name intention
121121
model = Qwen3_5ForConditionalGeneration()
122122
model.model = nn.Module()
123123
model.model.language_model = nn.Module()
124-
model.model.language_model.layers = nn.ModuleList([nn.Linear(4, 4)])
124+
if layers_as_module_dict:
125+
model.model.language_model.layers = nn.ModuleDict({"0": nn.Linear(4, 4)})
126+
else:
127+
model.model.language_model.layers = nn.ModuleList([nn.Linear(4, 4)])
125128
if not visual_is_none:
126129
model.model.visual = nn.Module()
127130
model.model.visual.blocks = nn.ModuleList([nn.Linear(4, 4)])
@@ -146,6 +149,31 @@ def test_none_intermediate_attribute_skipped_gracefully(self):
146149
assert len(layers) == 1
147150
assert isinstance(layers[0], nn.Linear)
148151

152+
def test_module_dict_pp_stage_layers_are_flattened(self):
153+
model = self._make_fake_qwen35(visual_is_none=True, layers_as_module_dict=True)
154+
# PP splitting replaces ModuleList with ModuleDict keyed by original layer ids.
155+
layers = parallelizer._extract_model_layers(model)
156+
assert len(layers) == 1
157+
assert isinstance(layers[0], nn.Linear)
158+
159+
def test_unknown_pp_stage_module_dict_heuristic(self):
160+
class UnknownPPSplitStage(nn.Module):
161+
pass
162+
163+
model = UnknownPPSplitStage()
164+
model.model = nn.Module()
165+
model.model.language_model = nn.Module()
166+
model.model.language_model.layers = nn.ModuleDict(
167+
{
168+
"0": nn.Linear(4, 4),
169+
"1": nn.Linear(4, 4),
170+
}
171+
)
172+
173+
layers = parallelizer._extract_model_layers(model)
174+
assert len(layers) == 2
175+
assert all(isinstance(x, nn.Linear) for x in layers)
176+
149177

150178
class TestAutoPipelineDeferFsdpGradSyncConversion:
151179
"""AutoPipeline's surface uses the existing FSDP2Config-style knob

0 commit comments

Comments (0)