
Commit c54393a

chore: temporarily reverted combination of fsdp units
1 parent 8f84b2d commit c54393a

2 files changed

Lines changed: 10 additions & 31 deletions


src/modalities/config/config.py

Lines changed: 1 addition & 1 deletion
@@ -275,7 +275,7 @@ class FSDP2WrappedModelConfig(BaseModel):
     mixed_precision_settings: FSDP2MixedPrecisionSettings
     reshard_after_forward: bool = True
     device_mesh: PydanticDeviceMeshIFType
-    layers_per_fsdp_unit: int = 1
+    # layers_per_fsdp_unit: int = 1

     @model_validator(mode="after")
     def validate_mixed_precision_settings(self):
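For orientation, the class after this commit looks roughly as follows, with the grouping knob parked as a comment rather than deleted. This is a minimal sketch assuming Pydantic v2; the Any annotations stand in for the project's FSDP2MixedPrecisionSettings and PydanticDeviceMeshIFType types, and the validator body is omitted:

from typing import Any

from pydantic import BaseModel, model_validator


class FSDP2WrappedModelConfig(BaseModel):
    # Any is a stand-in for FSDP2MixedPrecisionSettings / PydanticDeviceMeshIFType.
    mixed_precision_settings: Any
    reshard_after_forward: bool = True
    device_mesh: Any
    # layers_per_fsdp_unit: int = 1  # temporarily disabled by this commit

    @model_validator(mode="after")
    def validate_mixed_precision_settings(self):
        # Validation details omitted in this sketch.
        return self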

src/modalities/models/model_factory.py

Lines changed: 9 additions & 30 deletions
@@ -1,7 +1,3 @@
-# Some portions of this implementation are inspired, adapted, or refactored
-# from Meta's open-source project TorchTitan,
-# licensed under the BSD 3-Clause License.
-
 import itertools
 import json
 import time
@@ -172,7 +168,6 @@ def get_fsdp2_wrapped_model(
         device_mesh: DeviceMesh,
         mixed_precision_settings: FSDP2MixedPrecisionSettings,
         reshard_after_forward: bool,
-        layers_per_fsdp_unit: int = 1,
     ) -> FSDP2:
         """Get the FSDP2-wrapped model.

@@ -186,7 +181,6 @@ def get_fsdp2_wrapped_model(
             device_mesh (DeviceMesh): The device mesh.
             mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.
             reshard_after_forward (bool): Whether to reshard after forward.
-            layers_per_fsdp_unit (int): Number of layers per FSDP unit. Default is 1.

         Returns:
             FSDP2: The FSDP2-wrapped model.
@@ -211,32 +205,17 @@ def get_fsdp2_wrapped_model(
         fsdp_config = {"mesh": device_mesh[fsdp2_degrees], "mp_policy": mp_policy}

         modules = list(model.modules())
-
         # we first shard all the blocks
-        grouped_modules: list[nn.Module] = []
-        module_id = 0
         for module_id, module in enumerate(modules):
             if isinstance(module, block_types):
-                grouped_modules.append(module)
-                if len(grouped_modules) == layers_per_fsdp_unit:
-                    # As an optimization, we do not reshard after forward for the last
-                    # transformer block since FSDP would prefetch it immediately.
-                    reshard_block_after_forward = reshard_after_forward and int(module_id) < len(modules) - 1
-                    fully_shard(
-                        grouped_modules,
-                        **fsdp_config,
-                        reshard_after_forward=reshard_block_after_forward,
-                    )
-                    grouped_modules = list()
-
-        if len(grouped_modules) > 0:
-            reshard_block_after_forward = False
-            fully_shard(
-                grouped_modules,
-                **fsdp_config,
-                reshard_after_forward=reshard_block_after_forward,
-            )
-
+                # As an optimization, we do not reshard after forward for the last
+                # transformer block since FSDP would prefetch it immediately.
+                reshard_block_after_forward = reshard_after_forward and int(module_id) < len(modules) - 1
+                fully_shard(
+                    module,
+                    **fsdp_config,
+                    reshard_after_forward=reshard_block_after_forward,
+                )
         # finally, we shard the entire model
         fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward)
         logger.info(
@@ -763,4 +742,4 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh)
             parallelize_plan=transformer_block_tp_plan,
         )

-        return model
+        return model
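For context, the commit swaps the grouped wrapping (several transformer blocks combined into one FSDP unit, controlled by layers_per_fsdp_unit) back to per-block wrapping. Below is a minimal sketch of both strategies using PyTorch's composable fully_shard API; it assumes torch.distributed is already initialized (e.g. via torchrun), and TransformerBlock, the mesh handling, and the helper names are illustrative rather than the project's actual code:

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard  # torch >= 2.6; older releases expose it under torch.distributed._composable.fsdp


def shard_per_block(model: nn.Module, block_types: tuple, mesh: DeviceMesh,
                    reshard_after_forward: bool = True) -> nn.Module:
    # Behaviour after this commit: every matching block becomes its own FSDP unit.
    modules = list(model.modules())
    for module_id, module in enumerate(modules):
        if isinstance(module, block_types):
            # Do not reshard the last block after forward; FSDP would
            # all-gather it again right away for backward.
            reshard = reshard_after_forward and module_id < len(modules) - 1
            fully_shard(module, mesh=mesh, reshard_after_forward=reshard)
    # Wrap the root last so parameters outside the blocks are sharded too.
    fully_shard(model, mesh=mesh, reshard_after_forward=reshard_after_forward)
    return model


def shard_grouped(model: nn.Module, block_types: tuple, mesh: DeviceMesh,
                  layers_per_unit: int = 2, reshard_after_forward: bool = True) -> nn.Module:
    # The removed variant: pass a list of blocks to fully_shard so that
    # layers_per_unit blocks form a single FSDP unit (fewer, larger all-gathers).
    blocks = [m for m in model.modules() if isinstance(m, block_types)]
    group: list[nn.Module] = []
    for i, block in enumerate(blocks):
        group.append(block)
        if len(group) == layers_per_unit or i == len(blocks) - 1:
            reshard = reshard_after_forward and i < len(blocks) - 1
            fully_shard(group, mesh=mesh, reshard_after_forward=reshard)
            group = []
    fully_shard(model, mesh=mesh, reshard_after_forward=reshard_after_forward)
    return model


# Illustrative usage inside a torchrun job (not part of the commit):
#   mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,))
#   sharded = shard_per_block(model, (TransformerBlock,), mesh)

Grouping several blocks per unit means fewer but larger all-gathers at the cost of holding more unsharded parameters at once; the commit message indicates the grouped path is only disabled temporarily.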
