|
@@ -32,7 +32,7 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 import jax.numpy as jnp
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import Mesh
 from maxtext.common.common_types import (
     Array,
     AttentionType,
@@ -74,7 +74,7 @@
 from maxtext.layers.initializers import variable_to_logically_partitioned
 from maxtext.layers.quantizations import AqtQuantization as Quant
 from maxtext.utils import max_utils
-from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name
+from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec
 import numpy as np
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask
@@ -1455,26 +1455,19 @@ def kernel_fn(q, k, v, d, s):
 
       return attention_output, None
 
-    def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
-      # decoder_segment_ids can be None
-      if pspec is None:
-        return None
-      sharding = NamedSharding(self.mesh, pspec)
-      return maybe_shard_with_name(
-          inputs,
-          sharding,
-          shard_mode=self.config.shard_mode,
-          debug_sharding=self.config.debug_sharding,
-          extra_stack_level=1,
-      )
-
-    query = _maybe_shard_with_pspec(query, axis_names_q)
-    key = _maybe_shard_with_pspec(key, axis_names_kv)
-    value = _maybe_shard_with_pspec(value, axis_names_kv)
-    decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
-    decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
-    sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
-    indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names)
+    query = maybe_shard_with_pspec(query, self.mesh, self.config.shard_mode, axis_names_q, self.config.debug_sharding)
+    key = maybe_shard_with_pspec(key, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
+    value = maybe_shard_with_pspec(value, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
+    decoder_segment_ids_q = maybe_shard_with_pspec(
+        decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_q, self.config.debug_sharding
+    )
+    decoder_segment_ids_kv = maybe_shard_with_pspec(
+        decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_kv, self.config.debug_sharding
+    )
+    sinks = maybe_shard_with_pspec(sinks, self.mesh, self.config.shard_mode, sink_axis_names, self.config.debug_sharding)
+    indexer_mask = maybe_shard_with_pspec(
+        indexer_mask, self.mesh, self.config.shard_mode, indexer_mask_axis_names, self.config.debug_sharding
+    )
 
     ret = wrap_flash_attention(
         query,
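
For context, here is a minimal sketch of the `maybe_shard_with_pspec` signature implied by the call sites above and by the inline `_maybe_shard_with_pspec` it replaces. This is an assumption inferred from the diff, not the actual implementation in `maxtext/utils/sharding.py`, which may differ in details:

```python
# Hypothetical sketch only -- inferred from the call sites in this diff, not
# copied from maxtext.utils.sharding.
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def maybe_shard_with_pspec(inputs, mesh: Mesh, shard_mode, pspec: PartitionSpec | None, debug_sharding: bool):
  """Applies a sharding constraint to `inputs`, tolerating None values.

  `pspec` is None for optional inputs (e.g. decoder_segment_ids), in which
  case the input is passed through unconstrained.
  """
  if inputs is None or pspec is None:
    return inputs
  # Assumption: shard_mode and debug_sharding are forwarded to
  # maybe_shard_with_name, as the removed inline helper did.
  return jax.lax.with_sharding_constraint(inputs, NamedSharding(mesh, pspec))
```

Passing `self.mesh`, `shard_mode`, and `debug_sharding` explicitly at each call site removes the per-method closure and lets the helper live in `maxtext.utils.sharding` alongside `maybe_shard_with_name`, where other call sites can reuse it.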
|