@@ -1195,46 +1195,55 @@ def validate_tp_mesh(model, tp_mesh):
11951195 )
11961196
11971197
def _find_largest_module_list(model: nn.Module) -> Optional[Union[nn.ModuleList, nn.ModuleDict]]:
    """
    Heuristic function to find the largest layer container in a model.

    Every descendant of ``model`` is visited in pre-order and the container holding
    the most modules is returned. Two kinds of container qualify:

    * any ``nn.ModuleList``;
    * an ``nn.ModuleDict`` whose keys are all ASCII digit strings. Pipeline splitting
      converts ModuleLists to ModuleDicts keyed by original layer index, so such a
      dict is treated as a layer container.

    This is useful as a fallback when the model architecture is unknown, since
    transformer layers are typically organized in ModuleLists.

    Args:
        model (nn.Module): The model to search through. Only its descendants are
            considered; ``model`` itself is never returned.

    Returns:
        Optional[Union[nn.ModuleList, nn.ModuleDict]]: The largest layer container
        found, or None if no non-empty container exists. When several containers
        share the largest size, the first one in traversal order is returned.
    """

    def _is_pp_layer_module_dict(module: nn.ModuleDict) -> bool:
        # functional.py converts split ModuleLists to ModuleDicts with stringified
        # numeric indices. Avoid treating arbitrary named ModuleDicts (for example
        # adapter registries) as transformer layer containers in the heuristic path.
        # str.isdigit() alone also accepts non-ASCII digits such as "²", which are
        # never produced by str(int), so restrict the check to ASCII.
        return all(key.isascii() and key.isdigit() for key in module.keys())

    largest_module_list: Optional[Union[nn.ModuleList, nn.ModuleDict]] = None
    largest_size = 0

    # named_modules() walks the tree in pre-order with dotted paths and remembers
    # what it has already yielded, so shared submodules are inspected once and a
    # reference cycle between modules cannot cause unbounded recursion.
    for current_path, module in model.named_modules():
        if module is model:
            # The root is yielded first; only its descendants are candidates.
            continue

        is_layer_container = isinstance(module, nn.ModuleList) or (
            isinstance(module, nn.ModuleDict) and _is_pp_layer_module_dict(module)
        )
        if not is_layer_container:
            continue

        current_size = len(module)
        # Strict comparison: empty containers are never selected and ties keep the
        # container that was encountered first.
        if current_size > largest_size:
            largest_size = current_size
            largest_module_list = module
            logger.debug(f"Found {type(module).__name__} at {current_path} with {current_size} modules")

    if largest_module_list is not None:
        logger.info(f"Largest layer container found with {largest_size} modules")
    else:
        logger.warning("No ModuleList or ModuleDict found in the model")

    return largest_module_list
12401249
@@ -1320,6 +1329,8 @@ def _extend_layers(layers, modules):
13201329 for m in modules :
13211330 if isinstance (m , nn .ModuleList ):
13221331 layers .extend (m )
1332+ elif isinstance (m , nn .ModuleDict ):
1333+ layers .extend (m .values ())
13231334 else :
13241335 layers .append (m )
13251336
@@ -1338,15 +1349,20 @@ def _extend_layers(layers, modules):
13381349 elif hasattr (model , "layers" ):
13391350 layers .extend (model .layers )
13401351 else :
1341- # Use heuristic to find the largest ModuleList in the model
1352+ # Use heuristic to find the largest layer container in the model.
13421353 logger .warning (f"Unknown model type: { model_cls } . Using heuristic to find transformer layers." )
13431354 largest_module_list = _find_largest_module_list (model )
13441355 if largest_module_list is None :
1345- # If no ModuleList found, still raise an exception
1356+ # If no layer container is found, still raise an exception.
13461357 print (model )
1347- raise ValueError (f"Unknown model type: { model_cls } and no ModuleList found in model structure" )
1358+ raise ValueError (
1359+ f"Unknown model type: { model_cls } and no ModuleList or ModuleDict found in model structure"
1360+ )
13481361
1349- layers .extend (largest_module_list )
1362+ if isinstance (largest_module_list , nn .ModuleDict ):
1363+ layers .extend (largest_module_list .values ())
1364+ else :
1365+ layers .extend (largest_module_list )
13501366 logger .info (f"Successfully extracted { len (largest_module_list )} layers using heuristic" )
13511367
13521368 assert all (isinstance (m , nn .Module ) for m in layers ), "layers shoudl be nn.Module instances"
0 commit comments