Skip to content

Commit ce8a7de

Browse files
Merge pull request #3595 from AI-Hypercomputer:explicitpp
PiperOrigin-RevId: 897868174
2 parents 58e2c7e + b8cf440 commit ce8a7de

3 files changed

Lines changed: 39 additions & 23 deletions

File tree

src/maxtext/layers/attention_op.py

Lines changed: 15 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
3333
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
3434
import jax.numpy as jnp
35-
from jax.sharding import Mesh, NamedSharding
35+
from jax.sharding import Mesh
3636
from maxtext.common.common_types import (
3737
Array,
3838
AttentionType,
@@ -74,7 +74,7 @@
7474
from maxtext.layers.initializers import variable_to_logically_partitioned
7575
from maxtext.layers.quantizations import AqtQuantization as Quant
7676
from maxtext.utils import max_utils
77-
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name
77+
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec
7878
import numpy as np
7979
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel
8080
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask
@@ -1455,26 +1455,19 @@ def kernel_fn(q, k, v, d, s):
14551455

14561456
return attention_output, None
14571457

1458-
def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
1459-
# decoder_segment_ids can be None
1460-
if pspec is None:
1461-
return None
1462-
sharding = NamedSharding(self.mesh, pspec)
1463-
return maybe_shard_with_name(
1464-
inputs,
1465-
sharding,
1466-
shard_mode=self.config.shard_mode,
1467-
debug_sharding=self.config.debug_sharding,
1468-
extra_stack_level=1,
1469-
)
1470-
1471-
query = _maybe_shard_with_pspec(query, axis_names_q)
1472-
key = _maybe_shard_with_pspec(key, axis_names_kv)
1473-
value = _maybe_shard_with_pspec(value, axis_names_kv)
1474-
decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
1475-
decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
1476-
sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
1477-
indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names)
1458+
query = maybe_shard_with_pspec(query, self.mesh, self.config.shard_mode, axis_names_q, self.config.debug_sharding)
1459+
key = maybe_shard_with_pspec(key, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
1460+
value = maybe_shard_with_pspec(value, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
1461+
decoder_segment_ids_q = maybe_shard_with_pspec(
1462+
decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_q, self.config.debug_sharding
1463+
)
1464+
decoder_segment_ids_kv = maybe_shard_with_pspec(
1465+
decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_kv, self.config.debug_sharding
1466+
)
1467+
sinks = maybe_shard_with_pspec(sinks, self.mesh, self.config.shard_mode, sink_axis_names, self.config.debug_sharding)
1468+
indexer_mask = maybe_shard_with_pspec(
1469+
indexer_mask, self.mesh, self.config.shard_mode, indexer_mask_axis_names, self.config.debug_sharding
1470+
)
14781471

14791472
ret = wrap_flash_attention(
14801473
query,

src/maxtext/layers/moe.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
from maxtext.kernels import megablox as mblx
3737
from maxtext.utils import max_logging
3838
from maxtext.utils import max_utils
39-
from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding
39+
from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding, maybe_shard_with_pspec
4040
from maxtext.utils.sharding import logical_to_mesh_axes
4141
import numpy as np
4242
import qwix.pallas as qpl
@@ -1391,6 +1391,16 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13911391
gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes)
13921392
pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes)
13931393

1394+
w0_kernel = maybe_shard_with_pspec(w0_kernel, self.mesh, self.config.shard_mode, w0_pspec)
1395+
w1_kernel = maybe_shard_with_pspec(w1_kernel, self.mesh, self.config.shard_mode, w1_pspec)
1396+
wo_kernel = maybe_shard_with_pspec(wo_kernel, self.mesh, self.config.shard_mode, wo_pspec)
1397+
if w0_bias is not None:
1398+
w0_bias = maybe_shard_with_pspec(w0_bias, self.mesh, self.config.shard_mode, w0_bias_pspec)
1399+
if w1_bias is not None:
1400+
w1_bias = maybe_shard_with_pspec(w1_bias, self.mesh, self.config.shard_mode, w1_bias_pspec)
1401+
if wo_bias is not None:
1402+
wo_bias = maybe_shard_with_pspec(wo_bias, self.mesh, self.config.shard_mode, wo_bias_pspec)
1403+
13941404
return wrapper(
13951405
inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs
13961406
)

src/maxtext/utils/sharding.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,19 @@ def maybe_shard_with_name(
115115
return jax.lax.with_sharding_constraint(inputs, named_sharding)
116116

117117

118+
def maybe_shard_with_pspec(inputs, mesh, shard_mode, pspec: jax.sharding.PartitionSpec | None, debug_sharding=False):
  """Constrain `inputs` to the sharding described by `pspec` on `mesh`.

  A convenience wrapper around `maybe_shard_with_name` that builds the
  `NamedSharding` from a raw `PartitionSpec`.

  Args:
    inputs: the array (or pytree leaf) to constrain.
    mesh: the device mesh the partition spec refers to.
    shard_mode: forwarded to `maybe_shard_with_name` unchanged.
    pspec: partition spec for `inputs`; `None` means "no tensor here"
      (e.g. an optional input such as decoder segment ids), in which case
      `None` is returned and no constraint is applied.
    debug_sharding: forwarded to `maybe_shard_with_name` unchanged.

  Returns:
    The sharding-constrained `inputs`, or `None` when `pspec` is `None`.
  """
  # Optional inputs are passed through as None rather than constrained.
  if pspec is None:
    return None
  return maybe_shard_with_name(
      inputs,
      NamedSharding(mesh, pspec),
      shard_mode=shard_mode,
      debug_sharding=debug_sharding,
      # One extra frame for this wrapper so debug messages point at the caller.
      extra_stack_level=1,
  )
129+
130+
118131
def maybe_shard_with_logical(
119132
inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc=""
120133
):

0 commit comments

Comments (0)