Skip to content

Commit a7d38d0

Browse files
committed
enable pipeline parallelism (pp) with the DeepSeek batch-split path
1 parent 2762207 commit a7d38d0

8 files changed

Lines changed: 346 additions & 347 deletions

File tree

src/MaxText/layers/attention_op.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,8 @@ def maybe_create_nnx(einsum, *args):
524524
self.AqtEinsum_3 = jnp.einsum
525525

526526
def _logical_to_mesh_axes(self, logical_name):
527-
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)
527+
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
528+
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
528529

529530
def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
530531
"""Check attention inputs."""

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ def gmm(
755755
input_buffer_count,
756756
combine_scopes,
757757
):
758-
if config.use_qwix_quantization:
758+
if config.use_qwix_quantization or config.using_pipeline_parallelism:
759759
output = megablox.gmm(
760760
lhs=inputs,
761761
rhs=kernel,

src/MaxText/layers/moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,8 @@ def _maybe_shard_with_logical(self, inputs, logical_name):
462462
)
463463

464464
def _logical_to_mesh_axes(self, logical_name):
465-
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)
465+
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
466+
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
466467

467468
def get_expert_parallelism_size(self):
468469
return self.mesh.shape.get("expert", 1)

0 commit comments

Comments (0)