@@ -1,7 +1,3 @@
-# Some portions of this implementation are inspired, adapted, or refactored
-# from Meta's open-source project TorchTitan,
-# licensed under the BSD 3-Clause License.
-
 import itertools
 import json
 import time
@@ -172,7 +168,6 @@ def get_fsdp2_wrapped_model(
     device_mesh: DeviceMesh,
     mixed_precision_settings: FSDP2MixedPrecisionSettings,
     reshard_after_forward: bool,
-    layers_per_fsdp_unit: int = 1,
 ) -> FSDP2:
     """Get the FSDP2-wrapped model.
 
@@ -186,7 +181,6 @@ def get_fsdp2_wrapped_model(
         device_mesh (DeviceMesh): The device mesh.
         mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.
         reshard_after_forward (bool): Whether to reshard after forward.
-        layers_per_fsdp_unit (int): Number of layers per FSDP unit. Default is 1.
 
     Returns:
         FSDP2: The FSDP2-wrapped model.
@@ -211,32 +205,17 @@ def get_fsdp2_wrapped_model(
     fsdp_config = {"mesh": device_mesh[fsdp2_degrees], "mp_policy": mp_policy}
 
     modules = list(model.modules())
-
     # we first shard all the blocks
-    grouped_modules: list[nn.Module] = []
-    module_id = 0
     for module_id, module in enumerate(modules):
         if isinstance(module, block_types):
-            grouped_modules.append(module)
-            if len(grouped_modules) == layers_per_fsdp_unit:
-                # As an optimization, we do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately.
-                reshard_block_after_forward = reshard_after_forward and int(module_id) < len(modules) - 1
-                fully_shard(
-                    grouped_modules,
-                    **fsdp_config,
-                    reshard_after_forward=reshard_block_after_forward,
-                )
-                grouped_modules = list()
-
-    if len(grouped_modules) > 0:
-        reshard_block_after_forward = False
-        fully_shard(
-            grouped_modules,
-            **fsdp_config,
-            reshard_after_forward=reshard_block_after_forward,
-        )
-
+            # As an optimization, we do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately.
+            reshard_block_after_forward = reshard_after_forward and int(module_id) < len(modules) - 1
+            fully_shard(
+                module,
+                **fsdp_config,
+                reshard_after_forward=reshard_block_after_forward,
+            )
     # finally, we shard the entire model
     fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward)
     logger.info(
@@ -763,4 +742,4 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh)
         parallelize_plan=transformer_block_tp_plan,
     )
 
-    return model
+    return model
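
For context, here is a minimal, self-contained sketch of the per-block sharding pattern this commit switches to. It is not the repository's code: `ToyBlock`, `shard_model`, and the process-group setup are hypothetical stand-ins, and it assumes PyTorch 2.6+, where `fully_shard` and `MixedPrecisionPolicy` are exported from `torch.distributed.fsdp` (older releases keep them under `torch.distributed._composable.fsdp`).

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


def shard_model(model: nn.Module, reshard_after_forward: bool = True) -> nn.Module:
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    blocks = [m for m in model.modules() if isinstance(m, ToyBlock)]
    for i, block in enumerate(blocks):
        # Same optimization as in the diff: keep the last block's parameters
        # gathered after forward, since backward prefetches that block
        # immediately and resharding would only add an extra all-gather.
        fully_shard(
            block,
            mesh=mesh,
            mp_policy=mp_policy,
            reshard_after_forward=reshard_after_forward and i < len(blocks) - 1,
        )
    # Shard the root module last so parameters outside the blocks
    # (e.g. embeddings, final norm) are covered by their own FSDP unit.
    fully_shard(model, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=reshard_after_forward)
    return model


if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc-per-node=2 this_script.py
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    model = shard_model(nn.Sequential(*(ToyBlock(128) for _ in range(4))).cuda())
    dist.destroy_process_group()
```

The final root-level `fully_shard` mirrors the diff's "finally, we shard the entire model" step: it gathers every parameter not already owned by a block unit into one last FSDP unit, so nothing is left unsharded.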