Skip to content

Commit 8db7d24

Browse files
authored
Merge pull request #452 from Modalities/3B_training_prep
3 b training prep
2 parents 7337fe4 + e88b8aa commit 8db7d24

12 files changed

Lines changed: 337 additions & 34 deletions

File tree

src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,18 @@ def load_optimizer_checkpoint_(self, optimizer: Optimizer, model: FSDP, file_pat
103103
class DCPCheckpointLoading(DistributedCheckpointLoadingIF):
104104
"""Distributed checkpoint loading interface for loading PyTorch models and optimizer checkpoints."""
105105

106-
def __init__(self, global_rank: int):
106+
def __init__(self, global_rank: int, allow_partial_load: bool = False):
107107
"""
108108
Initializes the DCPCheckpointLoading object.
109109
110110
Args:
111111
global_rank (int): The global rank of the process.
112-
112+
allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to False.
113113
Returns:
114114
None
115115
"""
116116
self._global_rank = global_rank
117+
self._allow_partial_load = allow_partial_load
117118

118119
@torch.no_grad()
119120
def load_checkpoint_(self, app_state: AppState, checkpoint_dir_path: Path):
@@ -129,5 +130,6 @@ def load_checkpoint_(self, app_state: AppState, checkpoint_dir_path: Path):
129130
dcp.load(
130131
state_dict={"app": app_state},
131132
checkpoint_id=checkpoint_dir_path,
133+
planner=dcp.DefaultLoadPlanner(allow_partial_load=self._allow_partial_load),
132134
)
133135
get_logger().info(f"Distributed checkpoint loaded on rank {self._global_rank}.")

src/modalities/checkpointing/stateful/app_state.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ class AppState(Stateful):
3737
"""
3838

3939
def __init__(
40-
self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
40+
self,
41+
model: nn.Module | list[nn.Module],
42+
optimizer: Optimizer,
43+
lr_scheduler: Optional[LRScheduler] = None,
44+
components_to_load: list[StatefulComponents] | None = None,
4145
):
4246
"""Initializes the AppState object.
4347
@@ -46,12 +50,29 @@ def __init__(
4650
a non-sharded model, FSDP1 or FSDP2 model.
4751
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
4852
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
53+
components_to_load (list[StatefulComponents] | None, optional): The list of components to load from the
54+
checkpoint. If None, all components are loaded. Defaults to None.
4955
"""
5056
self._model_parts = list(model) if isinstance(model, list) else [model]
5157
self._optimizer = optimizer
5258
self._lr_scheduler = lr_scheduler
5359
self._is_loaded = False
5460

61+
# policy for which components to load from the checkpoint. If None, defaults to loading all components.
62+
if components_to_load is None:
63+
self._components_to_load = [StatefulComponents.MODEL, StatefulComponents.OPTIMIZER]
64+
if lr_scheduler is not None:
65+
self._components_to_load.append(StatefulComponents.LR_SCHEDULER)
66+
else:
67+
self._components_to_load = components_to_load
68+
69+
invalid_components = [c for c in self._components_to_load if not isinstance(c, StatefulComponents)]
70+
if invalid_components:
71+
raise ValueError(
72+
f"components_to_load must only contain StatefulComponents, but got invalid entries: "
73+
f"{invalid_components}"
74+
)
75+
5576
@property
5677
def is_loaded(self) -> bool:
5778
"""Returns whether the state dict has been loaded.
@@ -106,12 +127,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
106127
"Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded."
107128
)
108129

109-
ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value])
110-
OptimizerStateRetriever.load_state_dict_(
111-
app_state=self,
112-
state_dict=state_dict[StatefulComponents.OPTIMIZER.value],
113-
)
114-
if self._lr_scheduler is not None:
130+
if StatefulComponents.MODEL in self._components_to_load:
131+
ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value])
132+
if StatefulComponents.OPTIMIZER in self._components_to_load:
133+
OptimizerStateRetriever.load_state_dict_(
134+
app_state=self,
135+
state_dict=state_dict[StatefulComponents.OPTIMIZER.value],
136+
)
137+
if self._lr_scheduler is not None and StatefulComponents.LR_SCHEDULER in self._components_to_load:
115138
LRSchedulerStateRetriever.load_state_dict_(
116139
app_state=self, state_dict=state_dict[StatefulComponents.LR_SCHEDULER.value]
117140
)

src/modalities/checkpointing/stateful/app_state_factory.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,18 @@
77
from torch.optim.lr_scheduler import LRScheduler
88

99
from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading
10-
from modalities.checkpointing.stateful.app_state import AppState
10+
from modalities.checkpointing.stateful.app_state import AppState, StatefulComponents
1111

1212

1313
class AppStateFactory:
1414
"""Factory class to create AppState objects."""
1515

1616
@staticmethod
1717
def get_raw_app_state(
18-
model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
18+
model: nn.Module | list[nn.Module],
19+
optimizer: Optimizer,
20+
lr_scheduler: Optional[LRScheduler] = None,
21+
components_to_load: list[StatefulComponents] | None = None,
1922
) -> AppState:
2023
"""Creates a new (non-checkpoint loaded) AppState object from an instantiated
2124
model, optimizer, and optional learning rate scheduler.
@@ -25,24 +28,35 @@ def get_raw_app_state(
2528
a non-sharded model, FSDP1 or FSDP2 model.
2629
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
2730
lr_scheduler (Optional[LRScheduler], optional): Lr scheduler used during training. Defaults to None.
31+
components_to_load (list[StatefulComponents] | None, optional): Subset of components that should
32+
be restored from a checkpoint when ``load_state_dict`` is later invoked. If None, all
33+
available components are loaded. Defaults to None.
2834
2935
Returns:
3036
AppState: The AppState object.
3137
"""
32-
app_state = AppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
38+
app_state = AppState(
39+
model=model,
40+
optimizer=optimizer,
41+
lr_scheduler=lr_scheduler,
42+
components_to_load=components_to_load,
43+
)
3344
return app_state
3445

3546
@staticmethod
3647
def get_dcp_checkpointed_app_state_(
3748
raw_app_state: AppState,
3849
checkpoint_dir_path: Path,
50+
allow_partial_load: bool = False,
3951
) -> AppState:
4052
"""Loads the checkpointed state dict into the raw AppState object
4153
(i.e., non-checkpoint loaded AppState) in-place.
4254
4355
Args:
44-
raw_app_state (AppState): The raw AppState object.
56+
raw_app_state (AppState): The raw AppState object. Its ``components_to_load`` policy
57+
determines which components are restored.
4558
checkpoint_dir_path (Path): The path to the checkpoint directory.
59+
allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to False.
4660
4761
Raises:
4862
RuntimeError: Raises an error if the state dict has already been loaded.
@@ -52,8 +66,9 @@ def get_dcp_checkpointed_app_state_(
5266
"""
5367
if raw_app_state.is_loaded:
5468
raise RuntimeError(
55-
"Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded."
69+
"Cannot call load_state_dict twice on the same AppState object. State dict has already been loaded."
5670
)
57-
cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank())
71+
72+
cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank(), allow_partial_load=allow_partial_load)
5873
cp_loading.load_checkpoint_(app_state=raw_app_state, checkpoint_dir_path=checkpoint_dir_path)
5974
return raw_app_state

src/modalities/config/config.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from transformers import LlamaTokenizer as LlamaTokenizerFast
1212
from typing_extensions import deprecated
1313

14+
from modalities.checkpointing.stateful.app_state import StatefulComponents
1415
from modalities.config.lookup_enum import LookupEnum
1516
from modalities.config.pydantic_if_types import (
1617
PydanticAppStateType,
@@ -33,6 +34,7 @@
3334
PydanticTokenizerIFType,
3435
)
3536
from modalities.config.utils import parse_torch_device
37+
from modalities.models.weight_tying import has_tied_word_embeddings
3638
from modalities.running_env.env_utils import (
3739
FSDP2MixedPrecisionSettings,
3840
MixedPrecisionSettings,
@@ -124,10 +126,6 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy:
124126
return parse_enum_by_name(name=name, enum_type=ShardingStrategy)
125127

126128

127-
class DCPCheckpointLoadingConfig(BaseModel):
128-
global_rank: Annotated[int, Field(strict=True, ge=0)]
129-
130-
131129
class FSDP1CheckpointSavingConfig(BaseModel):
132130
checkpoint_path: Path
133131
global_rank: Annotated[int, Field(strict=True, ge=0)]
@@ -340,6 +338,13 @@ def validate_tp_mesh_existence(self) -> "GPT2ModelTPConfig":
340338
raise ValueError("data_parallel_replicate_degree > 1 cannot be used with Tensor Parallelism.")
341339
return self
342340

341+
@model_validator(mode="after")
342+
def validate_untied_word_embeddings(self) -> "GPT2ModelTPConfig":
343+
models = self.model if isinstance(self.model, list) else [self.model]
344+
if any(has_tied_word_embeddings(model) for model in models):
345+
raise ValueError("Tied word embeddings are not supported with Tensor Parallelism.")
346+
return self
347+
343348

344349
class CompiledModelConfig(BaseModel):
345350
model: PydanticPytorchModuleOrListType
@@ -382,11 +387,13 @@ class RawAppStateConfig(BaseModel):
382387
model: PydanticPytorchModuleOrListType
383388
optimizer: PydanticOptimizerIFType
384389
lr_scheduler: Optional[PydanticLRSchedulerIFType] = None
390+
components_to_load: Optional[list[StatefulComponents]] = None
385391

386392

387393
class DCPAppStateConfig(BaseModel):
388394
raw_app_state: PydanticAppStateType
389395
checkpoint_dir_path: Path
396+
allow_partial_load: bool = False
390397

391398

392399
class PreTrainedHFTokenizerConfig(BaseModel):

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,12 @@ def __init__(
11211121
self.transformer.lm_head.weight
11221122
) # https://paperswithcode.com/method/weight-tying
11231123

1124+
@property
1125+
def has_tied_word_embeddings(self) -> bool:
1126+
token_embedding_weight = getattr(self.transformer.wte, "weight", None)
1127+
lm_head_weight = getattr(self.transformer.lm_head, "weight", None)
1128+
return token_embedding_weight is not None and token_embedding_weight is lm_head_weight
1129+
11241130
@overload
11251131
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
11261132
"""

src/modalities/models/gpt2/llama3_like_initialization.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
class Llama3InitializerConfig(BaseModel):
1616
num_layers: Annotated[int, Field(strict=True, gt=0)]
1717
n_embd: Annotated[int, Field(strict=True, gt=0)]
18+
use_weight_tying: bool
1819
depth_init: bool = True
1920

2021

@@ -23,7 +24,7 @@ class Llama3Initializer(ModelInitializationIF):
2324
Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.
2425
"""
2526

26-
def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
27+
def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_tying: bool) -> None:
2728
"""
2829
Initializes the Llama3Initializer.
2930
Args:
@@ -39,16 +40,6 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
3940
self.regex_to_init = {
4041
# embedding weights
4142
r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
42-
# lm head weights
43-
r"transformer\.lm_head\.weight": (
44-
trunc_normal_,
45-
{
46-
"mean": 0.0,
47-
"std": 1 / math.sqrt(n_embd),
48-
"a": -3 / math.sqrt(n_embd),
49-
"b": 3 / math.sqrt(n_embd),
50-
},
51-
),
5243
# qkv projections
5344
r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
5445
trunc_normal_,
@@ -97,6 +88,17 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
9788
},
9889
),
9990
}
91+
if not use_weight_tying:
92+
# lm head weights
93+
self.regex_to_init[r"transformer\.lm_head\.weight"] = (
94+
trunc_normal_,
95+
{
96+
"mean": 0.0,
97+
"std": 1 / math.sqrt(n_embd),
98+
"a": -3 / math.sqrt(n_embd),
99+
"b": 3 / math.sqrt(n_embd),
100+
},
101+
)
100102

101103
def initialize_in_place(self, model: nn.Module):
102104
self._init_by_fqn_regex(model, self.regex_to_init)

src/modalities/models/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def weight_decay_groups(self) -> WeightDecayGroups:
4646
"""
4747
return self._weight_decay_groups
4848

49+
@property
50+
def has_tied_word_embeddings(self) -> bool:
51+
"""Whether the model currently uses tied token embedding and output weights."""
52+
return False
53+
4954
@abstractmethod
5055
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
5156
"""

src/modalities/models/parallelism/pipeline_parallelism_configs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Annotated
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel, Field, model_validator
44

55
from modalities.config.pydantic_if_types import (
66
PydanticDeviceMeshIFType,
@@ -11,6 +11,7 @@
1111
PydanticStagesGeneratorType,
1212
)
1313
from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes
14+
from modalities.models.weight_tying import has_tied_word_embeddings
1415
from modalities.utils.deprecated_alias import add_deprecated_alias
1516

1617

@@ -26,6 +27,12 @@ class StagedPipelineConfig(BaseModel):
2627
pp_schedule_name: str
2728
num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)]
2829

30+
@model_validator(mode="after")
31+
def validate_untied_word_embeddings(self) -> "StagedPipelineConfig":
32+
if has_tied_word_embeddings(self.whole_model):
33+
raise ValueError("Tied word embeddings are not supported with Pipeline Parallelism.")
34+
return self
35+
2936

3037
class ScheduledPipelineConfig(BaseModel):
3138
loss_fn: PydanticLossIFType
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch.nn as nn
2+
3+
4+
def has_tied_word_embeddings(model: nn.Module) -> bool:
5+
model_has_tied_word_embeddings = getattr(model, "has_tied_word_embeddings", None)
6+
if model_has_tied_word_embeddings is None:
7+
raise TypeError(
8+
f"{type(model).__name__} must define 'has_tied_word_embeddings' to be used with tied-embedding validation."
9+
)
10+
11+
return bool(model_has_tied_word_embeddings)

src/modalities/registry/components.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
SaveEveryKStepsCheckpointingStrategy,
1414
SaveKMostRecentCheckpointsStrategy,
1515
)
16-
from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading, FSDP1CheckpointLoading
16+
from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import FSDP1CheckpointLoading
1717
from modalities.checkpointing.fsdp.fsdp_checkpoint_saving import DCPCheckpointSaving, FSDP1CheckpointSaving
1818
from modalities.checkpointing.stateful.app_state_factory import AppStateFactory
1919
from modalities.checkpointing.torch.torch_checkpoint_loading import TorchCheckpointLoading
@@ -29,7 +29,6 @@
2929
ConstantLRSchedulerConfig,
3030
CosineAnnealingLRSchedulerConfig,
3131
DCPAppStateConfig,
32-
DCPCheckpointLoadingConfig,
3332
DCPCheckpointSavingConfig,
3433
DebuggingEnrichedModelConfig,
3534
DistributedSamplerConfig,
@@ -358,7 +357,7 @@ class ComponentEntity:
358357
ComponentEntity("checkpoint_saving_execution", "dcp", DCPCheckpointSaving, DCPCheckpointSavingConfig),
359358
# checkpoint loading
360359
ComponentEntity("checkpoint_loading", "fsdp1", FSDP1CheckpointLoading, FSDP1CheckpointLoadingConfig),
361-
ComponentEntity("checkpoint_loading", "dcp", DCPCheckpointLoading, DCPCheckpointLoadingConfig),
360+
# ComponentEntity("checkpoint_loading", "dcp", DCPCheckpointLoading, DCPCheckpointLoadingConfig),
362361
ComponentEntity("checkpoint_loading", "torch", TorchCheckpointLoading, TorchCheckpointLoadingConfig),
363362
# Progress subscriber
364363
ComponentEntity(

0 commit comments

Comments
 (0)