
Commit ede150e

rrutmann and Copilot committed
chore: Change order of model initialization
Co-authored-by: Copilot <copilot@github.com>
1 parent 7541df2 commit ede150e

2 files changed

Lines changed: 193 additions & 3 deletions

tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
settings:
  experiment_id: ${modalities_env:experiment_id}
  config_file_path: ${modalities_env:config_file_path}
  referencing_keys:
    sample_key: input_ids
    target_key: target_ids
    prediction_key: logits
  cuda_env:
    local_rank: ${cuda_env:LOCAL_RANK}
    global_rank: ${cuda_env:RANK}
    world_size: ${cuda_env:WORLD_SIZE}
  step_profile:
    gradient_accumulation_steps: 1
    local_train_micro_batch_size: 4
    sequence_length: 256

loss_fn:
  component_key: loss
  variant_key: clm_cross_entropy_loss
  config:
    target_key: ${settings.referencing_keys.target_key}
    prediction_key: ${settings.referencing_keys.prediction_key}

device_mesh:
  component_key: device_mesh
  variant_key: default
  config:
    device_type: cuda
    data_parallel_replicate_degree: 1
    pipeline_parallel_degree: 2
    data_parallel_shard_degree: -1
    world_size: ${settings.cuda_env.world_size}

initialized_model:
  component_key: model
  variant_key: model_initialized
  config:
    model:
      component_key: pipeline
      variant_key: selector
      config:
        pipeline:
          instance_key: scheduled_pipeline
          pass_type: BY_REFERENCE
        selection_type: MODEL_PART
    model_initializer:
      component_key: model_initialization
      variant_key: composed
      config:
        model_type: gpt2
        weight_init_type: scaled
        mean: 0.0
        std: 0.02
        seed: 42
        num_layers: ${model_raw.config.n_layer}
    device_mesh:
      instance_key: device_mesh
      pass_type: BY_REFERENCE

scheduled_pipeline:
  component_key: pipeline
  variant_key: scheduled
  config:
    loss_fn:
      instance_key: loss_fn
      pass_type: BY_REFERENCE
    pp_schedule_name: Interleaved1F1B
    batch_size: ${settings.step_profile.local_train_micro_batch_size}
    microbatch_size: 2
    pp_degree: ${device_mesh.config.pipeline_parallel_degree}
    pipeline:
      component_key: pipeline
      variant_key: builder
      config:
        pp_stage:
          component_key: pipeline
          variant_key: selector
          config:
            pipeline:
              instance_key: staged_pipeline
              pass_type: BY_REFERENCE
            selection_type: PP_STAGE
        model_part:
          instance_key: fsdp_model
          pass_type: BY_REFERENCE

fsdp_model:
  component_key: model
  variant_key: fsdp2_wrapped
  config:
    model:
      instance_key: gpt2_tp_model
      pass_type: BY_REFERENCE
    device_mesh:
      instance_key: device_mesh
      pass_type: BY_REFERENCE
    mixed_precision_settings:
      param_dtype: BF_16
      reduce_dtype: BF_16
    block_names: [GPT2Block]

gpt2_tp_model:
  component_key: model
  variant_key: gpt2_tp
  config:
    model:
      instance_key: model_part
      pass_type: BY_REFERENCE
    device_mesh:
      instance_key: device_mesh
      pass_type: BY_REFERENCE

model_part:
  component_key: pipeline
  variant_key: selector
  config:
    pipeline:
      instance_key: staged_pipeline
      pass_type: BY_REFERENCE
    selection_type: MODEL_PART

staged_pipeline:
  component_key: pipeline
  variant_key: staged
  config:
    whole_model:
      instance_key: model_raw
      pass_type: BY_REFERENCE
    stages_generator:
      component_key: stages_generator
      variant_key: gpt2_stages_generator
      config:
        num_model_layers: ${model_raw.config.n_layer}
        input_layer_equivalence: 1
        output_layer_equivalence: 1
    device_mesh:
      instance_key: device_mesh
      pass_type: BY_REFERENCE
    local_rank: ${settings.cuda_env.local_rank}
    pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name}
    num_layers_per_stage: 4

model_raw:
  component_key: model
  variant_key: gpt2
  config:
    use_meta_device: true
    use_weight_tying: false
    sample_key: ${settings.referencing_keys.sample_key}
    poe_type: NOPE
    sequence_length: ${settings.step_profile.sequence_length}
    prediction_key: ${loss_fn.config.prediction_key}
    vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: 6
    n_head_q: 8
    n_head_kv: 4
    ffn_hidden: 128
    n_embd: 128
    dropout: 0.0
    bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    attention_config:
      qkv_transforms:
        - type_hint: RotaryTransform
          config:
            n_embd: ${model_raw.config.n_embd}
            n_head: ${model_raw.config.n_head_q} # it has to be head_q here
            seq_length_dim: -2
            base_freq: 10000
    attention_implementation: manual
    activation_type: swiglu
    attention_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1e-5
    ffn_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1e-5
    lm_head_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1e-5
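Note on the config above: model_raw is built on the meta device (use_meta_device: true), and initialized_model runs the composed weight initializer on the model parts selected from the already scheduled pipeline, i.e. weights are materialized and initialized per pipeline stage rather than for the whole model up front. A minimal, generic PyTorch sketch of that pattern (illustration only, not the modalities implementation; the toy module and sizes are made up):

import torch
import torch.nn as nn

# 1) Build the full model on the meta device: parameter metadata only, no memory is allocated.
with torch.device("meta"):
    whole_model = nn.Sequential(*[nn.Linear(128, 128) for _ in range(6)])

# 2) Keep only the blocks belonging to this pipeline stage (first 3 of 6 here).
stage = nn.Sequential(*list(whole_model.children())[:3])

# 3) Materialize the stage on a real device and only then initialize weights,
#    so every rank touches only the parameters it actually owns.
stage = stage.to_empty(device="cpu")
torch.manual_seed(42)  # mirrors `seed: 42` in the config above
for module in stage.modules():
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        nn.init.zeros_(module.bias)

print(all(p.device.type == "cpu" for p in stage.parameters()))  # True

In the config, the pipeline selector and staged/scheduled pipeline components play the role of step 2, and the composed model initializer plays the role of step 3.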

tests/fsdp2_parallelization/test_parallel_seed_initialization.py

Lines changed: 7 additions & 3 deletions
@@ -68,7 +68,7 @@ def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_
     def _seed_distribution_impl(self, world_size: int, tmp_path: Path):
         # initialize components
         class ComponentsInstantiationModel(BaseModel):
-            fsdp_model: PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType]
+            initialized_model: PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType]
             device_mesh: PydanticDeviceMeshIFType

         config_file_path = self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path)
@@ -78,7 +78,10 @@ class ComponentsInstantiationModel(BaseModel):
             main_obj.build_components(components_model_type=ComponentsInstantiationModel),
         )
         model = cast(
-            Any, components.fsdp_model[0] if isinstance(components.fsdp_model, list) else components.fsdp_model
+            Any,
+            components.initialized_model[0]
+            if isinstance(components.initialized_model, list)
+            else components.initialized_model,
         )
         device_mesh = components.device_mesh
         # for each pp stage get first transformer block's MLP weight parameter shards and full tensor
@@ -148,7 +151,8 @@ def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degre
         temp_file_path = tmp_path / "pp_tp_sharding_config.yaml"
         working_dir = Path(os.path.dirname(__file__))
         config_file_path = (
-            working_dir / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml"
+            working_dir
+            / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml"
         )

         with open(config_file_path, "r") as file:
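The helper _get_tmp_sharding_config_path shown in the last hunk reads the base config and presumably writes a temporary copy (pp_tp_sharding_config.yaml) with the requested parallel degrees patched in. A rough sketch of that idea, assuming PyYAML and a hypothetical helper name, not the repository's actual implementation:

from pathlib import Path

import yaml


def make_tmp_sharding_config(base_config_path: Path, tmp_path: Path, dp_degree: int, pp_degree: int) -> Path:
    # Load the committed base config (e.g. the defer_init YAML added in this commit).
    with open(base_config_path, "r") as file:
        config = yaml.safe_load(file)

    # Patch only fields that exist in the config shown above.
    mesh_config = config["device_mesh"]["config"]
    mesh_config["data_parallel_shard_degree"] = dp_degree
    mesh_config["pipeline_parallel_degree"] = pp_degree

    # Write the patched copy next to the test's tmp_path and hand back its location.
    temp_file_path = tmp_path / "pp_tp_sharding_config.yaml"
    with open(temp_file_path, "w") as file:
        yaml.safe_dump(config, file)
    return temp_file_path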
