Skip to content

Commit e521a58

Browse files
committed
enable pipeline parallelism (pp) with batch-split DeepSeek (ds)
1 parent 0a64711 commit e521a58

9 files changed

Lines changed: 352 additions & 352 deletions

File tree

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ rope_truncate: True
5656
rope_attention_scaling: False
5757

5858
override_logical_axis_rules: True
59-
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60-
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
59+
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60+
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
6161
logical_axis_rules: [
6262
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63-
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6464
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6565
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
6666
['activation_norm_length', ['context']],
6767
['activation_heads', []],
68+
['activation_stage', 'stage'],
6869
['embed', ['fsdp']],
6970
['embed_no_exp', ['fsdp']],
7071
['q_lora', ['fsdp']],
7172
['kv_lora', ['fsdp']],
73+
['layers', 'stage'],
7274
['q_lora_up_proj', ['fsdp_transpose', 'expert']],
7375
['kv_lora_up_proj', ['fsdp_transpose', 'expert']],
7476
['q_heads', ['fsdp_transpose', 'expert']],

src/maxtext/configs/types.py

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2409,75 +2409,75 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24092409

24102410
# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
24112411
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
2412-
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2413-
self.ici_parallelism = [
2414-
self.ici_diloco_parallelism,
2415-
self.ici_pipeline_parallelism,
2416-
self.ici_data_parallelism,
2417-
self.ici_fsdp_parallelism,
2418-
self.ici_fsdp_transpose_parallelism,
2419-
self.ici_sequence_parallelism,
2420-
self.ici_context_parallelism,
2421-
self.ici_context_autoregressive_parallelism,
2422-
self.ici_tensor_parallelism,
2423-
self.ici_tensor_transpose_parallelism,
2424-
self.ici_tensor_sequence_parallelism,
2425-
self.ici_expert_parallelism,
2426-
self.ici_autoregressive_parallelism,
2427-
]
2428-
self.dcn_parallelism = [
2429-
self.dcn_diloco_parallelism,
2430-
self.dcn_pipeline_parallelism,
2431-
self.dcn_data_parallelism,
2432-
self.dcn_fsdp_parallelism,
2433-
self.dcn_fsdp_transpose_parallelism,
2434-
self.dcn_sequence_parallelism,
2435-
self.dcn_context_parallelism,
2436-
self.dcn_context_autoregressive_parallelism,
2437-
self.dcn_tensor_parallelism,
2438-
self.dcn_tensor_transpose_parallelism,
2439-
self.dcn_tensor_sequence_parallelism,
2440-
self.dcn_expert_parallelism,
2441-
self.dcn_autoregressive_parallelism,
2442-
]
2443-
else:
2444-
ici_map = {
2445-
"diloco": self.ici_diloco_parallelism,
2446-
"data": self.ici_data_parallelism,
2447-
"stage": self.ici_pipeline_parallelism,
2448-
"fsdp": self.ici_fsdp_parallelism,
2449-
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2450-
"sequence": self.ici_sequence_parallelism,
2451-
"context": self.ici_context_parallelism,
2452-
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2453-
"tensor": self.ici_tensor_parallelism,
2454-
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2455-
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2456-
"model": self.ici_tensor_parallelism,
2457-
"expert": self.ici_expert_parallelism,
2458-
"autoregressive": self.ici_autoregressive_parallelism,
2459-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2460-
}
2461-
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2462-
2463-
dcn_map = {
2464-
"diloco": self.dcn_diloco_parallelism,
2465-
"data": self.dcn_data_parallelism,
2466-
"stage": self.dcn_pipeline_parallelism,
2467-
"fsdp": self.dcn_fsdp_parallelism,
2468-
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2469-
"sequence": self.dcn_sequence_parallelism,
2470-
"context": self.dcn_context_parallelism,
2471-
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2472-
"tensor": self.dcn_tensor_parallelism,
2473-
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2474-
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2475-
"model": self.dcn_tensor_parallelism,
2476-
"expert": self.dcn_expert_parallelism,
2477-
"autoregressive": self.dcn_autoregressive_parallelism,
2478-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2479-
}
2480-
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
2412+
# if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2413+
# self.ici_parallelism = [
2414+
# self.ici_diloco_parallelism,
2415+
# self.ici_pipeline_parallelism,
2416+
# self.ici_data_parallelism,
2417+
# self.ici_fsdp_parallelism,
2418+
# self.ici_fsdp_transpose_parallelism,
2419+
# self.ici_sequence_parallelism,
2420+
# self.ici_context_parallelism,
2421+
# self.ici_context_autoregressive_parallelism,
2422+
# self.ici_tensor_parallelism,
2423+
# self.ici_tensor_transpose_parallelism,
2424+
# self.ici_tensor_sequence_parallelism,
2425+
# self.ici_expert_parallelism,
2426+
# self.ici_autoregressive_parallelism,
2427+
# ]
2428+
# self.dcn_parallelism = [
2429+
# self.dcn_diloco_parallelism,
2430+
# self.dcn_pipeline_parallelism,
2431+
# self.dcn_data_parallelism,
2432+
# self.dcn_fsdp_parallelism,
2433+
# self.dcn_fsdp_transpose_parallelism,
2434+
# self.dcn_sequence_parallelism,
2435+
# self.dcn_context_parallelism,
2436+
# self.dcn_context_autoregressive_parallelism,
2437+
# self.dcn_tensor_parallelism,
2438+
# self.dcn_tensor_transpose_parallelism,
2439+
# self.dcn_tensor_sequence_parallelism,
2440+
# self.dcn_expert_parallelism,
2441+
# self.dcn_autoregressive_parallelism,
2442+
# ]
2443+
# else:
2444+
ici_map = {
2445+
"diloco": self.ici_diloco_parallelism,
2446+
"data": self.ici_data_parallelism,
2447+
"stage": self.ici_pipeline_parallelism,
2448+
"fsdp": self.ici_fsdp_parallelism,
2449+
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2450+
"sequence": self.ici_sequence_parallelism,
2451+
"context": self.ici_context_parallelism,
2452+
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2453+
"tensor": self.ici_tensor_parallelism,
2454+
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2455+
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2456+
"model": self.ici_tensor_parallelism,
2457+
"expert": self.ici_expert_parallelism,
2458+
"autoregressive": self.ici_autoregressive_parallelism,
2459+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2460+
}
2461+
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2462+
2463+
dcn_map = {
2464+
"diloco": self.dcn_diloco_parallelism,
2465+
"data": self.dcn_data_parallelism,
2466+
"stage": self.dcn_pipeline_parallelism,
2467+
"fsdp": self.dcn_fsdp_parallelism,
2468+
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2469+
"sequence": self.dcn_sequence_parallelism,
2470+
"context": self.dcn_context_parallelism,
2471+
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2472+
"tensor": self.dcn_tensor_parallelism,
2473+
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2474+
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2475+
"model": self.dcn_tensor_parallelism,
2476+
"expert": self.dcn_expert_parallelism,
2477+
"autoregressive": self.dcn_autoregressive_parallelism,
2478+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2479+
}
2480+
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
24812481

24822482
# Diloco params
24832483
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)

src/maxtext/layers/attention_op.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,9 @@ def maybe_create_nnx(einsum, *args):
591591
self.AqtEinsum_3 = jnp.einsum
592592

593593
def _logical_to_mesh_axes(self, logical_name):
594+
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
594595
return logical_to_mesh_axes(
595-
logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules
596+
logical_name, mesh=self.mesh, rules=logical_rules
596597
)
597598

598599
def check_attention_inputs(

src/maxtext/layers/decoders.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -995,11 +995,11 @@ def __call__(
995995
else:
996996
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
997997
logits = sharding.maybe_shard_with_logical(
998-
logits,
999-
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1000-
mesh=self.mesh,
1001-
shard_mode=self.config.shard_mode,
1002-
debug_sharding=self.config.debug_sharding,
998+
logits,
999+
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1000+
mesh=self.mesh,
1001+
shard_mode=self.config.shard_mode,
1002+
debug_sharding=self.config.debug_sharding,
10031003
)
10041004

10051005
# The API of the Decoder is now a tuple, providing both the main output

src/maxtext/layers/moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,8 @@ def _maybe_shard_with_logical(self, inputs, logical_name):
462462
)
463463

464464
def _logical_to_mesh_axes(self, logical_name):
465-
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)
465+
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
466+
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
466467

467468
def get_expert_parallelism_size(self):
468469
return self.mesh.shape.get("expert", 1)

0 commit comments

Comments (0)