Skip to content

Commit 33fefd9

Browse files
committed
enable pp with batch split ds
1 parent 5a44af0 commit 33fefd9

9 files changed

Lines changed: 352 additions & 352 deletions

File tree

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ rope_truncate: True
5656
rope_attention_scaling: False
5757

5858
override_logical_axis_rules: True
59-
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60-
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
59+
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60+
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
6161
logical_axis_rules: [
6262
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63-
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6464
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6565
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
6666
['activation_norm_length', ['context']],
6767
['activation_heads', []],
68+
['activation_stage', 'stage'],
6869
['embed', ['fsdp']],
6970
['embed_no_exp', ['fsdp']],
7071
['q_lora', ['fsdp']],
7172
['kv_lora', ['fsdp']],
73+
['layers', 'stage'],
7274
['q_lora_up_proj', ['fsdp_transpose', 'expert']],
7375
['kv_lora_up_proj', ['fsdp_transpose', 'expert']],
7476
['q_heads', ['fsdp_transpose', 'expert']],

src/maxtext/configs/types.py

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2502,75 +2502,75 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25022502

25032503
# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
25042504
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
2505-
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2506-
self.ici_parallelism = [
2507-
self.ici_diloco_parallelism,
2508-
self.ici_pipeline_parallelism,
2509-
self.ici_data_parallelism,
2510-
self.ici_fsdp_parallelism,
2511-
self.ici_fsdp_transpose_parallelism,
2512-
self.ici_sequence_parallelism,
2513-
self.ici_context_parallelism,
2514-
self.ici_context_autoregressive_parallelism,
2515-
self.ici_tensor_parallelism,
2516-
self.ici_tensor_transpose_parallelism,
2517-
self.ici_tensor_sequence_parallelism,
2518-
self.ici_expert_parallelism,
2519-
self.ici_autoregressive_parallelism,
2520-
]
2521-
self.dcn_parallelism = [
2522-
self.dcn_diloco_parallelism,
2523-
self.dcn_pipeline_parallelism,
2524-
self.dcn_data_parallelism,
2525-
self.dcn_fsdp_parallelism,
2526-
self.dcn_fsdp_transpose_parallelism,
2527-
self.dcn_sequence_parallelism,
2528-
self.dcn_context_parallelism,
2529-
self.dcn_context_autoregressive_parallelism,
2530-
self.dcn_tensor_parallelism,
2531-
self.dcn_tensor_transpose_parallelism,
2532-
self.dcn_tensor_sequence_parallelism,
2533-
self.dcn_expert_parallelism,
2534-
self.dcn_autoregressive_parallelism,
2535-
]
2536-
else:
2537-
ici_map = {
2538-
"diloco": self.ici_diloco_parallelism,
2539-
"data": self.ici_data_parallelism,
2540-
"stage": self.ici_pipeline_parallelism,
2541-
"fsdp": self.ici_fsdp_parallelism,
2542-
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2543-
"sequence": self.ici_sequence_parallelism,
2544-
"context": self.ici_context_parallelism,
2545-
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2546-
"tensor": self.ici_tensor_parallelism,
2547-
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2548-
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2549-
"model": self.ici_tensor_parallelism,
2550-
"expert": self.ici_expert_parallelism,
2551-
"autoregressive": self.ici_autoregressive_parallelism,
2552-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2553-
}
2554-
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2555-
2556-
dcn_map = {
2557-
"diloco": self.dcn_diloco_parallelism,
2558-
"data": self.dcn_data_parallelism,
2559-
"stage": self.dcn_pipeline_parallelism,
2560-
"fsdp": self.dcn_fsdp_parallelism,
2561-
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2562-
"sequence": self.dcn_sequence_parallelism,
2563-
"context": self.dcn_context_parallelism,
2564-
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2565-
"tensor": self.dcn_tensor_parallelism,
2566-
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2567-
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2568-
"model": self.dcn_tensor_parallelism,
2569-
"expert": self.dcn_expert_parallelism,
2570-
"autoregressive": self.dcn_autoregressive_parallelism,
2571-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2572-
}
2573-
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
2505+
# if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2506+
# self.ici_parallelism = [
2507+
# self.ici_diloco_parallelism,
2508+
# self.ici_pipeline_parallelism,
2509+
# self.ici_data_parallelism,
2510+
# self.ici_fsdp_parallelism,
2511+
# self.ici_fsdp_transpose_parallelism,
2512+
# self.ici_sequence_parallelism,
2513+
# self.ici_context_parallelism,
2514+
# self.ici_context_autoregressive_parallelism,
2515+
# self.ici_tensor_parallelism,
2516+
# self.ici_tensor_transpose_parallelism,
2517+
# self.ici_tensor_sequence_parallelism,
2518+
# self.ici_expert_parallelism,
2519+
# self.ici_autoregressive_parallelism,
2520+
# ]
2521+
# self.dcn_parallelism = [
2522+
# self.dcn_diloco_parallelism,
2523+
# self.dcn_pipeline_parallelism,
2524+
# self.dcn_data_parallelism,
2525+
# self.dcn_fsdp_parallelism,
2526+
# self.dcn_fsdp_transpose_parallelism,
2527+
# self.dcn_sequence_parallelism,
2528+
# self.dcn_context_parallelism,
2529+
# self.dcn_context_autoregressive_parallelism,
2530+
# self.dcn_tensor_parallelism,
2531+
# self.dcn_tensor_transpose_parallelism,
2532+
# self.dcn_tensor_sequence_parallelism,
2533+
# self.dcn_expert_parallelism,
2534+
# self.dcn_autoregressive_parallelism,
2535+
# ]
2536+
# else:
2537+
ici_map = {
2538+
"diloco": self.ici_diloco_parallelism,
2539+
"data": self.ici_data_parallelism,
2540+
"stage": self.ici_pipeline_parallelism,
2541+
"fsdp": self.ici_fsdp_parallelism,
2542+
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2543+
"sequence": self.ici_sequence_parallelism,
2544+
"context": self.ici_context_parallelism,
2545+
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2546+
"tensor": self.ici_tensor_parallelism,
2547+
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2548+
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2549+
"model": self.ici_tensor_parallelism,
2550+
"expert": self.ici_expert_parallelism,
2551+
"autoregressive": self.ici_autoregressive_parallelism,
2552+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2553+
}
2554+
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2555+
2556+
dcn_map = {
2557+
"diloco": self.dcn_diloco_parallelism,
2558+
"data": self.dcn_data_parallelism,
2559+
"stage": self.dcn_pipeline_parallelism,
2560+
"fsdp": self.dcn_fsdp_parallelism,
2561+
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2562+
"sequence": self.dcn_sequence_parallelism,
2563+
"context": self.dcn_context_parallelism,
2564+
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2565+
"tensor": self.dcn_tensor_parallelism,
2566+
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2567+
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2568+
"model": self.dcn_tensor_parallelism,
2569+
"expert": self.dcn_expert_parallelism,
2570+
"autoregressive": self.dcn_autoregressive_parallelism,
2571+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2572+
}
2573+
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
25742574

25752575
# Diloco params
25762576
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)

src/maxtext/layers/attention_op.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,9 @@ def maybe_create_nnx(einsum, *args):
595595
self.AqtEinsum_3 = jnp.einsum
596596

597597
def _logical_to_mesh_axes(self, logical_name):
598+
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
598599
return logical_to_mesh_axes(
599-
logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules
600+
logical_name, mesh=self.mesh, rules=logical_rules
600601
)
601602

602603
def check_attention_inputs(

src/maxtext/layers/decoders.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,11 +1033,11 @@ def __call__(
10331033
else:
10341034
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
10351035
logits = sharding.maybe_shard_with_logical(
1036-
logits,
1037-
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1038-
mesh=self.mesh,
1039-
shard_mode=self.config.shard_mode,
1040-
debug_sharding=self.config.debug_sharding,
1036+
logits,
1037+
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1038+
mesh=self.mesh,
1039+
shard_mode=self.config.shard_mode,
1040+
debug_sharding=self.config.debug_sharding,
10411041
)
10421042

10431043
# The API of the Decoder is now a tuple, providing both the main output

src/maxtext/layers/moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,8 @@ def _maybe_shard_with_logical(self, inputs, logical_name):
463463
)
464464

465465
def _logical_to_mesh_axes(self, logical_name):
466-
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)
466+
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
467+
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
467468

468469
def get_expert_parallelism_size(self):
469470
return self.mesh.shape.get("expert", 1)

0 commit comments

Comments (0)