Skip to content

Commit d05c015

Browse files
committed
enable pp with batch split ds
1 parent e4d05fc commit d05c015

9 files changed

Lines changed: 352 additions & 352 deletions

File tree

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -56,19 +56,21 @@ rope_truncate: True
5656
rope_attention_scaling: False
5757

5858
override_logical_axis_rules: True
59-
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60-
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
59+
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60+
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
6161
logical_axis_rules: [
6262
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63-
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6464
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6565
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
6666
['activation_norm_length', ['context']],
6767
['activation_heads', []],
68+
['activation_stage', 'stage'],
6869
['embed', ['fsdp']],
6970
['embed_no_exp', ['fsdp']],
7071
['q_lora', ['fsdp']],
7172
['kv_lora', ['fsdp']],
73+
['layers', 'stage'],
7274
['q_lora_up_proj', ['fsdp_transpose', 'expert']],
7375
['kv_lora_up_proj', ['fsdp_transpose', 'expert']],
7476
['q_heads', ['fsdp_transpose', 'expert']],

src/maxtext/configs/types.py

Lines changed: 69 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -2403,75 +2403,75 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24032403

24042404
# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
24052405
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
2406-
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2407-
self.ici_parallelism = [
2408-
self.ici_diloco_parallelism,
2409-
self.ici_pipeline_parallelism,
2410-
self.ici_data_parallelism,
2411-
self.ici_fsdp_parallelism,
2412-
self.ici_fsdp_transpose_parallelism,
2413-
self.ici_sequence_parallelism,
2414-
self.ici_context_parallelism,
2415-
self.ici_context_autoregressive_parallelism,
2416-
self.ici_tensor_parallelism,
2417-
self.ici_tensor_transpose_parallelism,
2418-
self.ici_tensor_sequence_parallelism,
2419-
self.ici_expert_parallelism,
2420-
self.ici_autoregressive_parallelism,
2421-
]
2422-
self.dcn_parallelism = [
2423-
self.dcn_diloco_parallelism,
2424-
self.dcn_pipeline_parallelism,
2425-
self.dcn_data_parallelism,
2426-
self.dcn_fsdp_parallelism,
2427-
self.dcn_fsdp_transpose_parallelism,
2428-
self.dcn_sequence_parallelism,
2429-
self.dcn_context_parallelism,
2430-
self.dcn_context_autoregressive_parallelism,
2431-
self.dcn_tensor_parallelism,
2432-
self.dcn_tensor_transpose_parallelism,
2433-
self.dcn_tensor_sequence_parallelism,
2434-
self.dcn_expert_parallelism,
2435-
self.dcn_autoregressive_parallelism,
2436-
]
2437-
else:
2438-
ici_map = {
2439-
"diloco": self.ici_diloco_parallelism,
2440-
"data": self.ici_data_parallelism,
2441-
"stage": self.ici_pipeline_parallelism,
2442-
"fsdp": self.ici_fsdp_parallelism,
2443-
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2444-
"sequence": self.ici_sequence_parallelism,
2445-
"context": self.ici_context_parallelism,
2446-
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2447-
"tensor": self.ici_tensor_parallelism,
2448-
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2449-
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2450-
"model": self.ici_tensor_parallelism,
2451-
"expert": self.ici_expert_parallelism,
2452-
"autoregressive": self.ici_autoregressive_parallelism,
2453-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2454-
}
2455-
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2456-
2457-
dcn_map = {
2458-
"diloco": self.dcn_diloco_parallelism,
2459-
"data": self.dcn_data_parallelism,
2460-
"stage": self.dcn_pipeline_parallelism,
2461-
"fsdp": self.dcn_fsdp_parallelism,
2462-
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2463-
"sequence": self.dcn_sequence_parallelism,
2464-
"context": self.dcn_context_parallelism,
2465-
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2466-
"tensor": self.dcn_tensor_parallelism,
2467-
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2468-
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2469-
"model": self.dcn_tensor_parallelism,
2470-
"expert": self.dcn_expert_parallelism,
2471-
"autoregressive": self.dcn_autoregressive_parallelism,
2472-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2473-
}
2474-
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
2406+
# if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2407+
# self.ici_parallelism = [
2408+
# self.ici_diloco_parallelism,
2409+
# self.ici_pipeline_parallelism,
2410+
# self.ici_data_parallelism,
2411+
# self.ici_fsdp_parallelism,
2412+
# self.ici_fsdp_transpose_parallelism,
2413+
# self.ici_sequence_parallelism,
2414+
# self.ici_context_parallelism,
2415+
# self.ici_context_autoregressive_parallelism,
2416+
# self.ici_tensor_parallelism,
2417+
# self.ici_tensor_transpose_parallelism,
2418+
# self.ici_tensor_sequence_parallelism,
2419+
# self.ici_expert_parallelism,
2420+
# self.ici_autoregressive_parallelism,
2421+
# ]
2422+
# self.dcn_parallelism = [
2423+
# self.dcn_diloco_parallelism,
2424+
# self.dcn_pipeline_parallelism,
2425+
# self.dcn_data_parallelism,
2426+
# self.dcn_fsdp_parallelism,
2427+
# self.dcn_fsdp_transpose_parallelism,
2428+
# self.dcn_sequence_parallelism,
2429+
# self.dcn_context_parallelism,
2430+
# self.dcn_context_autoregressive_parallelism,
2431+
# self.dcn_tensor_parallelism,
2432+
# self.dcn_tensor_transpose_parallelism,
2433+
# self.dcn_tensor_sequence_parallelism,
2434+
# self.dcn_expert_parallelism,
2435+
# self.dcn_autoregressive_parallelism,
2436+
# ]
2437+
# else:
2438+
ici_map = {
2439+
"diloco": self.ici_diloco_parallelism,
2440+
"data": self.ici_data_parallelism,
2441+
"stage": self.ici_pipeline_parallelism,
2442+
"fsdp": self.ici_fsdp_parallelism,
2443+
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2444+
"sequence": self.ici_sequence_parallelism,
2445+
"context": self.ici_context_parallelism,
2446+
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2447+
"tensor": self.ici_tensor_parallelism,
2448+
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2449+
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2450+
"model": self.ici_tensor_parallelism,
2451+
"expert": self.ici_expert_parallelism,
2452+
"autoregressive": self.ici_autoregressive_parallelism,
2453+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2454+
}
2455+
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2456+
2457+
dcn_map = {
2458+
"diloco": self.dcn_diloco_parallelism,
2459+
"data": self.dcn_data_parallelism,
2460+
"stage": self.dcn_pipeline_parallelism,
2461+
"fsdp": self.dcn_fsdp_parallelism,
2462+
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2463+
"sequence": self.dcn_sequence_parallelism,
2464+
"context": self.dcn_context_parallelism,
2465+
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2466+
"tensor": self.dcn_tensor_parallelism,
2467+
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2468+
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2469+
"model": self.dcn_tensor_parallelism,
2470+
"expert": self.dcn_expert_parallelism,
2471+
"autoregressive": self.dcn_autoregressive_parallelism,
2472+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2473+
}
2474+
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
24752475

24762476
# Diloco params
24772477
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)

src/maxtext/layers/attention_op.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -591,8 +591,9 @@ def maybe_create_nnx(einsum, *args):
591591
self.AqtEinsum_3 = jnp.einsum
592592

593593
def _logical_to_mesh_axes(self, logical_name):
594+
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
594595
return logical_to_mesh_axes(
595-
logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules
596+
logical_name, mesh=self.mesh, rules=logical_rules
596597
)
597598

598599
def check_attention_inputs(

src/maxtext/layers/decoders.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -952,11 +952,11 @@ def __call__(
952952
else:
953953
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
954954
logits = sharding.maybe_shard_with_logical(
955-
logits,
956-
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
957-
mesh=self.mesh,
958-
shard_mode=self.config.shard_mode,
959-
debug_sharding=self.config.debug_sharding,
955+
logits,
956+
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
957+
mesh=self.mesh,
958+
shard_mode=self.config.shard_mode,
959+
debug_sharding=self.config.debug_sharding,
960960
)
961961

962962
# The API of the Decoder is now a tuple, providing both the main output

src/maxtext/layers/moe.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -462,7 +462,8 @@ def _maybe_shard_with_logical(self, inputs, logical_name):
462462
)
463463

464464
def _logical_to_mesh_axes(self, logical_name):
465-
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)
465+
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
466+
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
466467

467468
def get_expert_parallelism_size(self):
468469
return self.mesh.shape.get("expert", 1)

0 commit comments

Comments
 (0)