Commit c775efa
chore: Merge pull request #418 from Modalities/pp_multi_stage
Multi stage pipeline parallelism support
2 parents: ec01495 + d0499d1

46 files changed: 2148 additions & 388 deletions


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -171,4 +171,5 @@ tutorials/instruction_tuning/prepared_data
 config_files/instruction_tuning
 data/lorem_ipsum_instruct.jsonl
 tutorials/scaling_up/logs*
-tutorials/scaling_up/experiments_old/*
+tutorials/scaling_up/experiments_old/*
+results/*

config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml

Lines changed: 4 additions & 4 deletions
@@ -194,7 +194,7 @@ device_mesh:
   config:
     device_type: cuda
     data_parallel_replicate_degree: 1
-    pipeline_parallel_degree: 2
+    pipeline_parallel_degree: 4
     data_parallel_shard_degree: -1
     world_size: ${settings.cuda_env.world_size}

@@ -251,7 +251,7 @@ scheduled_pipeline:
     loss_fn:
       instance_key: loss_fn
       pass_type: BY_REFERENCE
-    pp_schedule_name: gpipe
+    pp_schedule_name: Interleaved1F1B
     batch_size: ${settings.step_profile.local_train_micro_batch_size}
     microbatch_size: 2
     pp_degree: ${device_mesh.config.pipeline_parallel_degree}
@@ -318,7 +318,7 @@ staged_pipeline:
       instance_key: device_mesh
       pass_type: BY_REFERENCE
     local_rank: ${settings.cuda_env.local_rank}
-    pp_schedule_name: gpipe
+    pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name}
     num_layers_per_stage: 2

 model_raw:
@@ -332,7 +332,7 @@ model_raw:
     sequence_length: ${settings.step_profile.sequence_length}
     prediction_key: ${loss_fn.config.prediction_key}
     vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
-    n_layer: 2
+    n_layer: 6
     n_head_q: 8
     n_head_kv: 4
     ffn_hidden: 128
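
For orientation, the schedule names map onto the schedule classes in torch.distributed.pipelining: GPipe drives a single stage per rank, while Interleaved1F1B expects several stages per rank, which is presumably why the config also raises the pipeline degree and layer count. A minimal, illustrative sketch of the selection logic (not the repository's factory code; `stages` is assumed to be this rank's list of PipelineStage objects):

    from torch.distributed.pipelining import ScheduleGPipe, ScheduleInterleaved1F1B

    def build_pp_schedule(pp_schedule_name: str, stages: list, n_microbatches: int, loss_fn):
        if pp_schedule_name == "gpipe":
            # GPipe drives exactly one stage on this rank.
            return ScheduleGPipe(stages[0], n_microbatches=n_microbatches, loss_fn=loss_fn)
        if pp_schedule_name == "Interleaved1F1B":
            # Interleaved 1F1B overlaps forwards and backwards across multiple local stages.
            return ScheduleInterleaved1F1B(stages, n_microbatches=n_microbatches, loss_fn=loss_fn)
        raise ValueError(f"Unknown pipeline schedule: {pp_schedule_name}")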

config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml

Lines changed: 1 addition & 1 deletion
@@ -308,7 +308,7 @@ staged_pipeline:
       instance_key: device_mesh
       pass_type: BY_REFERENCE
     local_rank: ${settings.cuda_env.local_rank}
-    pp_schedule_name: gpipe
+    pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name}
     num_layers_per_stage: 2

 model_raw:

src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py

Lines changed: 2 additions & 1 deletion
@@ -89,7 +89,8 @@ def _save_checkpoint(self, app_state: AppState, training_progress: TrainingProgr
         # saving the model via FULL_STATE_DICT and checkpoint via FULL_OPTIM_STATE_DICT
         model_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
         optim_save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
-        model = app_state.model
+        assert len(app_state.model_parts) == 1, "FSDP1CheckpointSaving only supports a single model part."
+        model = app_state.model_parts[0]
         optimizer = app_state.optimizer
         with FSDP.state_dict_type(
             module=model,

src/modalities/checkpointing/stateful/app_state.py

Lines changed: 38 additions & 21 deletions
@@ -15,6 +15,8 @@
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler

+from modalities.optimizers.optimizer_list import OptimizersList
+

 class StatefulComponents(Enum):
     MODEL = "model"
@@ -34,15 +36,18 @@ class AppState(Stateful):
     https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
     """

-    def __init__(self, model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None):
+    def __init__(
+        self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
+    ):
         """Initializes the AppState object.

         Args:
-            model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
+            model (nn.Module | list[nn.Module]): The model or model parts can be either
+                a non-sharded model, FSDP1 or FSDP2 model.
             optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
             lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
         """
-        self._model = model
+        self._model_parts = list(model) if isinstance(model, list) else [model]
         self._optimizer = optimizer
         self._lr_scheduler = lr_scheduler
         self._is_loaded = False
@@ -56,8 +61,8 @@ def is_loaded(self) -> bool:
         return self._is_loaded

     @property
-    def model(self) -> nn.Module:
-        return self._model
+    def model_parts(self) -> list[nn.Module]:
+        return self._model_parts

     @property
     def optimizer(self) -> Optimizer:
@@ -153,15 +158,18 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]:
 class ModelStateRetriever(StateRetrieverIF):
     @staticmethod
     def get_state_dict(app_state: AppState) -> dict[str, Any]:
-        """Returns the state dict of the model in the AppState object.
+        """Returns the flattened state dicts of the model parts in the AppState object.

         Args:
             app_state (AppState): The app_state object containing the model.

         Returns:
             dict[str, Any]: The state dict of the model in the AppState object.
         """
-        return get_model_state_dict(model=app_state.model)
+        state_dicts = list(map(get_model_state_dict, app_state.model_parts))
+        state_dict_keys = sum((list(sd.keys()) for sd in state_dicts), [])
+        assert len(state_dict_keys) == len(set(state_dict_keys)), "State dict keys are not unique across model parts."
+        return {k: v for sd in state_dicts for k, v in sd.items()}

     @staticmethod
     def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
@@ -171,7 +179,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
             app_state (AppState): The app_state object containing the model.
             state_dict (dict[str, Any]): The state dict to load into the model.
         """
-        set_model_state_dict(model=app_state.model, model_state_dict=state_dict, options=StateDictOptions(strict=False))
+        for model in app_state.model_parts:
+            set_model_state_dict(model=model, model_state_dict=state_dict, options=StateDictOptions(strict=False))


 class OptimizerStateRetriever(StateRetrieverIF):
@@ -185,13 +194,17 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]:
         Returns:
             dict[str, Any]: The state dict of the optimizer in the AppState object.
         """
-        sd = get_optimizer_state_dict(
-            model=app_state.model,
-            optimizers=app_state.optimizer,
-            # NOTE: Flattening is required for pipeline parallelism to work correctly.
-            # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
-            options=StateDictOptions(flatten_optimizer_state_dict=True),
-        )
+        if isinstance(app_state.optimizer, OptimizersList):
+            sd = app_state.optimizer.state_dict()
+        else:
+            assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
+            sd = get_optimizer_state_dict(
+                model=app_state.model_parts[0],
+                optimizers=app_state.optimizer,
+                # NOTE: Flattening is required for pipeline parallelism to work correctly.
+                # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
+                options=StateDictOptions(flatten_optimizer_state_dict=True),
+            )
         return sd

     @staticmethod
@@ -202,12 +215,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
             app_state (AppState): The app_state object containing the optimizer.
             state_dict (dict[str, Any]): The state dict to load into the optimizer.
         """
-        set_optimizer_state_dict(
-            model=app_state.model,
-            optimizers=app_state.optimizer,
-            optim_state_dict=state_dict,
-            options=StateDictOptions(flatten_optimizer_state_dict=True),
-        )
+        if isinstance(app_state.optimizer, OptimizersList):
+            app_state.optimizer.load_state_dict(state_dict)
+        else:
+            assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
+            set_optimizer_state_dict(
+                model=app_state.model_parts[0],
+                optimizers=app_state.optimizer,
+                optim_state_dict=state_dict,
+                options=StateDictOptions(flatten_optimizer_state_dict=True),
+            )


 class LRSchedulerStateRetriever(StateRetrieverIF):
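
A short usage sketch of the list-based AppState above, with toy modules standing in for pipeline stages (in real pipeline parallelism the parts carry distinct fully qualified parameter names, which is what the key-uniqueness assertion relies on):

    import torch
    import torch.nn as nn

    from modalities.checkpointing.stateful.app_state import AppState, ModelStateRetriever

    # Two toy "pipeline parts" with non-overlapping parameter FQNs.
    part_0 = nn.ModuleDict({"layer_0": nn.Linear(4, 4)})
    part_1 = nn.ModuleDict({"layer_1": nn.Linear(4, 2)})
    optimizer = torch.optim.AdamW(list(part_0.parameters()) + list(part_1.parameters()))

    app_state = AppState(model=[part_0, part_1], optimizer=optimizer)
    merged_sd = ModelStateRetriever.get_state_dict(app_state)
    # merged_sd holds layer_0.weight, layer_0.bias, layer_1.weight, layer_1.bias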

src/modalities/checkpointing/stateful/app_state_factory.py

Lines changed: 3 additions & 2 deletions
@@ -15,13 +15,14 @@ class AppStateFactory:

     @staticmethod
     def get_raw_app_state(
-        model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
+        model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
     ) -> AppState:
         """Creates a new (non-checkpoint loaded) AppState object from an instantiated
         model, optimizer, and optional learning rate scheduler.

         Args:
-            model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
+            model (nn.Module | list[nn.Module]): The model (parts) can be either
+                a non-sharded model, FSDP1 or FSDP2 model.
             optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
             lr_scheduler (Optional[LRScheduler], optional): Lr scheduler used during training. Defaults to None.
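
With the widened signature, a pipeline-split model can be handed to the factory as a list; a sketch, assuming `stage_modules` and `optimizer` come from the pipeline split and optimizer setup:

    app_state = AppStateFactory.get_raw_app_state(
        model=stage_modules,  # list[nn.Module], one entry per local pipeline stage (assumed)
        optimizer=optimizer,  # e.g. an OptimizersList wrapping one optimizer per part
        lr_scheduler=None,
    )
    assert app_state.model_parts == stage_modules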

src/modalities/config/component_factory.py

Lines changed: 36 additions & 12 deletions
@@ -1,6 +1,7 @@
 from typing import Any, Type, TypeVar

-from pydantic import BaseModel
+from pydantic import AliasChoices, BaseModel
+from pydantic.fields import FieldInfo

 from modalities.registry.registry import Registry
 from modalities.util import print_rank_0
@@ -164,30 +165,53 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co
             config_dict=config_dict,
             component_config_type=component_config_type,
         )
-        comp_config = component_config_type(**config_dict, strict=True)
+        comp_config = component_config_type.model_validate(config_dict, extra="forbid")
         return comp_config

     def _assert_valid_config_keys(
         self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild]
     ) -> None:
-        required_keys = []
-        optional_keys = []
-        for key, field in component_config_type.model_fields.items():
+        # Collect required and optional keys, including aliases if defined.
+        required_keys: list[str] = []
+        optional_keys: list[str] = []
+        # Map aliases to canonical field names for clearer error messages.
+        alias_to_field: dict[str, str] = {}
+
+        for field_name, field in component_config_type.model_fields.items():
+            names_for_field = self._parse_str_aliases(alias_to_field, field_name, field)
             if field.is_required():
-                required_keys.append(key)
+                required_keys.extend(names_for_field)
             else:
-                optional_keys.append(key)
+                optional_keys.extend(names_for_field)

-        invalid_keys = []
-        for key in config_dict.keys():
-            if key not in required_keys and key not in optional_keys:
-                invalid_keys.append(key)
+        all_valid_keys = set(required_keys) | set(optional_keys)
+
+        invalid_keys = [key for key in config_dict.keys() if key not in all_valid_keys]
         if len(invalid_keys) > 0:
             message = f"Invalid keys {invalid_keys} for config `{component_key}.{variant_key}`"
             message += f" of type {component_config_type}:\n{config_dict}\n"
-            message += f"Required keys: {required_keys}\nOptional keys: {optional_keys}"
+            if alias_to_field:
+                message += f"Alias to field mapping: {alias_to_field}\n"
+            message += f"Required keys (including aliases): {required_keys}\n"
+            message += f"Optional keys (including aliases): {optional_keys}\n"
             raise ValueError(message)

+    def _parse_str_aliases(self, alias_to_field: dict[str, str], field_name: str, field: FieldInfo) -> set[str]:
+        names_for_field = {field_name}
+        if field.alias and field.alias != field_name:
+            names_for_field.add(field.alias)
+            alias_to_field[field.alias] = field_name
+        if field.validation_alias and field.validation_alias != field_name:
+            if isinstance(field.validation_alias, str):
+                names_for_field.add(field.validation_alias)
+                alias_to_field[field.validation_alias] = field_name
+            elif isinstance(field.validation_alias, AliasChoices):
+                for alias in field.validation_alias.choices:
+                    if isinstance(alias, str):
+                        names_for_field.add(alias)
+                        alias_to_field[alias] = field_name
+        return names_for_field
+
     def _instantiate_component(self, component_key: str, variant_key: str, component_config: BaseModel) -> Any:
         component_type: Type = self.registry.get_component(component_key, variant_key)
         component_config_dict = self._base_model_to_dict(component_config)
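
To illustrate the alias handling, consider a toy pydantic model (not from the repo) whose field accepts both its canonical name and a legacy key; `_parse_str_aliases` would report both spellings as valid keys and record the alias-to-field mapping used in the error message:

    from pydantic import AliasChoices, BaseModel, ConfigDict, Field

    class ToyConfig(BaseModel):
        # silence pydantic's protected "model_" namespace warning for this toy example
        model_config = ConfigDict(protected_namespaces=())
        model_parts: int = Field(validation_alias=AliasChoices("model_parts", "wrapped_model"))

    # Both spellings validate against the same field.
    assert ToyConfig.model_validate({"wrapped_model": 3}).model_parts == 3
    assert ToyConfig.model_validate({"model_parts": 3}).model_parts == 3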

src/modalities/config/config.py

Lines changed: 16 additions & 9 deletions
@@ -27,6 +27,7 @@
     PydanticModelInitializationIFType,
     PydanticOptimizerIFType,
     PydanticPytorchDeviceType,
+    PydanticPytorchModuleOrListType,
     PydanticPytorchModuleType,
     PydanticSamplerIFType,
     PydanticTokenizerIFType,
@@ -43,6 +44,7 @@
     ActivationCheckpointingVariants,
 )
 from modalities.util import parse_enum_by_name
+from modalities.utils.deprecated_alias import add_deprecated_alias


 class ProcessGroupBackendType(LookupEnum):
@@ -145,20 +147,24 @@ class CheckpointSavingConfig(BaseModel):

 class AdamOptimizerConfig(BaseModel):
     lr: float
-    wrapped_model: PydanticPytorchModuleType
+    wrapped_model: PydanticPytorchModuleOrListType
     betas: tuple[float, float]
     eps: float
     weight_decay: float
     weight_decay_groups_excluded: list[str]
+    # foreach: bool | None = None
+    # fused: bool | None = None


 class AdamWOptimizerConfig(BaseModel):
     lr: float
-    wrapped_model: PydanticPytorchModuleType
+    wrapped_model: PydanticPytorchModuleOrListType
     betas: tuple[float, float]
     eps: float
     weight_decay: float
     weight_decay_groups_excluded: list[str]
+    # foreach: bool | None = None
+    # fused: bool | None = None


 class DummyLRSchedulerConfig(BaseModel):
@@ -264,7 +270,7 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy:


 class FSDP2WrappedModelConfig(BaseModel):
-    model: PydanticPytorchModuleType
+    model: PydanticPytorchModuleOrListType
     block_names: list[str]
     mixed_precision_settings: FSDP2MixedPrecisionSettings
     reshard_after_forward: bool = True
@@ -289,7 +295,7 @@ def validate_dp_mesh_existence(self):


 class DebuggingEnrichedModelConfig(BaseModel):
-    model: PydanticPytorchModuleType
+    model: PydanticPytorchModuleOrListType
     logging_dir_path: Path
     tracked_ranks: Optional[Set[int]] = None
     log_interval_steps: Optional[int] = 1
@@ -302,7 +308,7 @@ def convert_list_to_set(cls, v: Iterable[int] | None) -> Set[int] | None:


 class GPT2ModelTPConfig(BaseModel):
-    model: PydanticPytorchModuleType # TODO set proper type
+    model: PydanticPytorchModuleOrListType # TODO set proper type
     device_mesh: PydanticDeviceMeshIFType

     @model_validator(mode="after")
@@ -325,7 +331,7 @@ class CompiledModelConfig(BaseModel):


 class WeightInitializedModelConfig(BaseModel):
-    model: PydanticPytorchModuleType
+    model: PydanticPytorchModuleOrListType
     model_initializer: PydanticModelInitializationIFType

     # avoid warning about protected namespace 'model_', see
@@ -350,12 +356,12 @@ class SelectiveOpACParams(BaseModel):

     ac_variant: ActivationCheckpointingVariants
     layers_fqn: str
-    model: PydanticPytorchModuleType
+    model: PydanticPytorchModuleOrListType
     ac_fun_params: FullACParams | SelectiveLayerACParams | SelectiveOpACParams


 class RawAppStateConfig(BaseModel):
-    model: PydanticPytorchModuleType
+    model: PydanticPytorchModuleOrListType
     optimizer: PydanticOptimizerIFType
     lr_scheduler: Optional[PydanticLRSchedulerIFType] = None

@@ -480,12 +486,13 @@ class RichResultSubscriberConfig(BaseModel):
     global_rank: int


+@add_deprecated_alias("model_parts", "wrapped_model")
 class GPT2MFUCalculatorConfig(BaseModel):
     n_layer: Annotated[int, Field(strict=True, gt=0)]
     sequence_length: Annotated[int, Field(strict=True, gt=0)]
     n_embd: Annotated[int, Field(strict=True, gt=0)]
     world_size: Annotated[int, Field(strict=True, gt=0)]
-    wrapped_model: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType
+    model_parts: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType]
     device_mesh: Optional[PydanticDeviceMeshIFType] = None
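The body of `add_deprecated_alias` (from modalities.utils.deprecated_alias) is not part of this excerpt. A plausible minimal sketch, assuming it merely attaches a validation alias so that configs using the old `wrapped_model` key keep working, might look like this (an assumption, not the repo's implementation):

    from pydantic import AliasChoices, BaseModel

    def add_deprecated_alias(field_name: str, deprecated_name: str):
        def decorator(cls: type[BaseModel]) -> type[BaseModel]:
            # Hypothetical: accept the deprecated key as an alias of the canonical field.
            cls.model_fields[field_name].validation_alias = AliasChoices(field_name, deprecated_name)
            cls.model_rebuild(force=True)  # rebuild the validation schema so the alias takes effect
            return cls
        return decorator

Under that assumption, GPT2MFUCalculatorConfig would accept the legacy `wrapped_model` key and map it onto `model_parts`.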

src/modalities/config/pydantic_if_types.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ def __get_pydantic_core_schema__(
     CheckpointSavingExecutionABC, PydanticThirdPartyTypeIF(CheckpointSavingExecutionABC)
 ]
 PydanticPytorchModuleType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)]
+PydanticPytorchModuleOrListType = PydanticPytorchModuleType | list[PydanticPytorchModuleType]
 PydanticFSDP1ModuleType = Annotated[FSDP1, PydanticThirdPartyTypeIF(FSDP1)]
 PydanticFSDP2ModuleType = Annotated[FSDP2, PydanticThirdPartyTypeIF(FSDP2)]
 PydanticTokenizerIFType = Annotated[TokenizerWrapper, PydanticThirdPartyTypeIF(TokenizerWrapper)]
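
The new union is what the config classes above rely on to accept either a single wrapped module or a list of pipeline-split parts; a minimal, hypothetical consumer:

    from pydantic import BaseModel

    class ExamplePartsConfig(BaseModel):  # illustration only, not a repo class
        model: PydanticPytorchModuleOrListType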
