
Commit eb52c0b
Add packed (THD) ring attention with hardware-aware reorder dispatch
Enable CP + packing for context_parallel_strategy="ring" with load balancing. On GPU, this uses Transformer Engine's striped reorder for THD-packed sequences. On TPU/CPU, it falls back to the pure-JAX reorder_sequence and never imports TE.

Changes:

- common_types: Add ReorderStrategy enum (AUTO, DUAL_CHUNK_SWAP, STRIPED).
- configs: Add context_parallel_reorder_strategy (default "auto"). Reject explicit STRIPED on non-GPU hardware at config validation time.
- attention_op: Thread segment_positions through apply_attention, cudnn_flash_attention, and __call__. Use segment_positions in TE's SequenceDescriptor for packing. Restrict packing + CP to load-balanced ring only. Note the TE version constraint.
- attentions.py, attention_mla.py, gpt3.py: Pass inputs_positions into attention_op calls (None for gpt3).
- max_utils: Hardware-dispatched reorder_causal_load_balanced. GPU uses TE's reorder_causal_load_balancing; TPU/CPU uses reorder_sequence. The TE import is lazy and GPU-only.
- maxtext_utils: Thread reorder_strategy and hardware through shard_reorder_causal_load_balanced and get_reorder_callable. The default hardware="tpu" never triggers a TE import.
- train_utils: Allow ring + packing; forbid all_gather + packing and synthetic + packing. Resolve AUTO to STRIPED when packing, else DUAL_CHUNK_SWAP (see the sketch below). Pass config.hardware to the reorder callable. Build data_loader after the reorder wrapper is applied.
- attention_test_util: Pass cfg_cp.hardware so TPU tests use the pure-JAX reorder. The helper is TPU-oriented and does not model GPU packed behavior.
- tests: Add test_gpu_ring_attention_with_packing (sm90+). Requires TE with reorder_causal_load_balancing; works with TE <=2.11 or >=2.14 (incompatible with 2.12 and 2.13 due to a known bug).
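A minimal sketch of the AUTO resolution rule from the train_utils bullet above. The helper name resolve_reorder_strategy is hypothetical; the actual resolution lives inline in train_utils.

from maxtext.common.common_types import ReorderStrategy

def resolve_reorder_strategy(configured: ReorderStrategy, packing: bool) -> ReorderStrategy:
  # Hypothetical helper restating the rule: an explicit choice wins
  # (config validation already rejects STRIPED on non-GPU hardware).
  if configured != ReorderStrategy.AUTO:
    return configured
  # Packed (THD) sequences use TE's striped reorder; unpacked causal
  # batches keep the dual-chunk mirror swap.
  return ReorderStrategy.STRIPED if packing else ReorderStrategy.DUAL_CHUNK_SWAP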
1 parent 3de696b commit eb52c0b

13 files changed: 223 additions & 51 deletions


src/maxtext/common/common_types.py

Lines changed: 10 additions & 0 deletions
@@ -122,6 +122,16 @@ class ShardMode(enum.Enum):
   EXPLICIT = "explicit"


+class ReorderStrategy(enum.Enum):
+  """Reorder strategies for load-balanced context parallelism.
+  Maps to transformer_engine.jax.attention.ReorderStrategy at runtime.
+  """
+
+  AUTO = "auto"
+  DUAL_CHUNK_SWAP = "dual_chunk_swap"
+  STRIPED = "striped"
+
+
 class HyperConnectionType(enum.Enum):
   ATTENTION = "attention"
   MLP_MOE = "mlp_moe"

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -1038,6 +1038,7 @@ use_splash_scheduler: False # to use tokamax splash attention scheduler.
 ### Determine if we want to use load balance for context parallelism
 context_parallel_load_balance: True
 context_parallel_strategy: "all_gather" # "all_gather" or "ring"
+context_parallel_reorder_strategy: "auto" # "auto", "dual_chunk_swap", or "striped"

 ### Paged Attention ###
 # These settings take effect only when `attention=paged`.

src/maxtext/configs/pyconfig_deprecated.py

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,7 @@

 from maxtext.utils import accelerator_to_spec_map
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR
-from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode
+from maxtext.common.common_types import AttentionType, DecoderBlockType, ReorderStrategy, ShardMode
 from maxtext.utils import gcs_utils
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils

@@ -856,6 +856,7 @@ def user_init(raw_keys):

     raw_keys["decoder_block"] = DecoderBlockType(raw_keys["decoder_block"])
     raw_keys["shard_mode"] = ShardMode(raw_keys["shard_mode"])
+    raw_keys["context_parallel_reorder_strategy"] = ReorderStrategy(raw_keys["context_parallel_reorder_strategy"])

   @staticmethod
   def configure_gpt3_task(raw_keys):
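For reference, the raw-string-to-enum conversion above behaves like any standard Python enum lookup, so an unknown value fails fast at config-parse time (illustrative snippet, not part of the commit):

from maxtext.common.common_types import ReorderStrategy

assert ReorderStrategy("striped") is ReorderStrategy.STRIPED
ReorderStrategy("stripes")  # raises ValueError: 'stripes' is not a valid ReorderStrategy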

src/maxtext/configs/types.py

Lines changed: 19 additions & 1 deletion
@@ -29,7 +29,7 @@
 from typing import Any, Literal, NewType, Optional

 import jax
-from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode
+from maxtext.common.common_types import AttentionType, DecoderBlockType, ReorderStrategy, ShardMode
 from maxtext.utils import gcs_utils
 from maxtext.utils import max_utils
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT

@@ -821,6 +821,10 @@ class HardwareAndMesh(BaseModel):
       "all_gather",
       description="Strategy for context parallelism ('all_gather' or 'ring').",
   )
+  context_parallel_reorder_strategy: ReorderStrategy = Field(
+      "auto",
+      description="Reorder strategy for load-balanced context parallelism.",
+  )
   custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']")
   custom_mesh_and_rule: str = Field("", description="Customized mesh and logical rules for granularity.")
   allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.")

@@ -2672,6 +2676,20 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       raise ValueError(
           "Ring context parallelism strategy (context_parallel_strategy='ring') is only supported on GPUs."
       )
+    # STRIPED reorder strategy is a Transformer Engine feature and is GPU-only.
+    # The AUTO + packing case (which training resolves to STRIPED) is not validated here
+    # because test code paths may load the same config but use a different reorder path.
+    # Training's runtime path in max_utils.reorder_causal_load_balanced enforces this.
+    if (
+        self.context_parallel_size > 1
+        and "gpu" not in self.hardware
+        and self.context_parallel_load_balance
+        and self.context_parallel_reorder_strategy == ReorderStrategy.STRIPED
+    ):
+      raise ValueError(
+          "STRIPED reorder strategy requires Transformer Engine and is only supported on GPUs. "
+          f"Got hardware={self.hardware!r}."
+      )
     if self.hardware == "gpu" and self.packing and self.attention == "cudnn_flash_te" and self.max_segments_per_seq <= 0:
       raise ValueError("max_segments_per_seq must be set when using TransformerEngine attention and packing")
     dcn_product = (
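A standalone restatement of the predicate this validator enforces (hypothetical function, shown only to make the rejected combination explicit):

def rejects_striped(hardware: str, cp_size: int, load_balance: bool, strategy: str) -> bool:
  # Mirrors the validator: explicit STRIPED is rejected whenever load-balanced
  # context parallelism is active on non-GPU hardware.
  return cp_size > 1 and "gpu" not in hardware and load_balance and strategy == "striped"

assert rejects_striped("tpu", cp_size=4, load_balance=True, strategy="striped")
assert not rejects_striped("gpu", cp_size=4, load_balance=True, strategy="striped")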

src/maxtext/layers/attention_mla.py

Lines changed: 1 addition & 0 deletions
@@ -1217,6 +1217,7 @@ def __call__(
         key,
         value,
         decoder_segment_ids,
+        inputs_positions,
         model_mode,
         cached_values,
         indexer_mask=indexer_mask,

src/maxtext/layers/attention_op.py

Lines changed: 29 additions & 8 deletions
@@ -871,6 +871,7 @@ def apply_attention(
       key: Array | KVTensor,
       value: Array | KVTensor,
       decoder_segment_ids: Array | None,
+      segment_positions: Array | None,
       lengths: Array | None,
       model_mode: str,
       use_ragged_attention: bool = False,

@@ -1003,7 +1004,7 @@
           Use `dot_product` instead."""
       )
       return (
-          self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode),
+          self.cudnn_flash_attention(query, key, value, decoder_segment_ids, segment_positions, model_mode),
           None,
           None,
       )

@@ -1513,12 +1514,15 @@ def cudnn_flash_attention(
       key: Array,
       value: Array,
       decoder_segment_ids: Array | None,
+      segment_positions: Array | None,
       model_mode: str = MODEL_MODE_TRAIN,
   ) -> Array:
     """CUDNN Flash Attention with Transformer Engine.
-
-    1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism 2.
-    Context Parallelism currently only supports causal masking and no packing
+    1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism
+    2. Context Parallelism currently only supports causal masking
+    3. Only Ring attention has packing support with striped load balancing
+       (context_parallel_strategy="ring" and context_parallel_load_balance=true)
+    4. Breaks with TE 2.12 and 2.13 (known bug); works with TE stable release <=2.11 or >=2.14.
     """
     # These imports are only meant to work in a GPU build.
     # pylint: disable=import-outside-toplevel

@@ -1528,6 +1532,11 @@ def cudnn_flash_attention(
     _, _, _, head_dim = query.shape  # pylint: disable=unused-variable

     using_context_parallelism = self.mesh.shape[self.config.context_sharding] > 1
+    using_load_balanced_ring_cp = (
+        using_context_parallelism
+        and self.config.context_parallel_strategy == "ring"
+        and self.config.context_parallel_load_balance
+    )

     # Initialize default attention configuration
     sliding_window_size = None

@@ -1541,18 +1550,27 @@ def cudnn_flash_attention(

     # Handle packing configurations
     if self.config.packing and self.config.dataset_type != "synthetic":
+      if using_context_parallelism and not using_load_balanced_ring_cp:
+        raise ValueError("Packing is only supported for load balanced ring attention with context parallelism.")
       qkv_layout = "THD_THD_THD"  # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD'
       if decoder_segment_ids is None:
         decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
-      attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=decoder_segment_ids, segment_pos=None)
+      attn_mask = SequenceDescriptor.from_segment_ids_and_pos(
+          segment_ids=decoder_segment_ids, segment_pos=segment_positions
+      )
       # Create dummy SequenceDescriptor for lazy_init
       dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
-      dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None)
+      dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(
+          segment_ids=dummy_segment_ids, segment_pos=segment_positions
+      )
       max_segments_per_seq = self.config.max_segments_per_seq
     elif using_context_parallelism:
       if self.attention_type == AttentionType.LOCAL_SLIDING:
-        raise AssertionError("Sliding window attention is not supported for context parallelism")
-      # Context parallelism without packing: only supports causal masking
+        raise AssertionError(
+            "Sliding window attention requires context parallelism with load-balanced ring strategy "
+            "and packing enabled."
+        )
+      # Context parallelism without packing: only supports causal masking, but not sliding window attention
       attn_mask = None
       dummy_attn_mask = None
       mask_type = "causal"

@@ -2003,6 +2021,7 @@ def __call__(
       key,
       value,
       decoder_segment_ids,
+      inputs_positions,
       model_mode,
       cached_values=None,
       previous_chunk=None,

@@ -2034,6 +2053,7 @@
         key=key,
         value=value,
         decoder_segment_ids=decoder_segment_ids,
+        segment_positions=inputs_positions,
         lengths=None,
         model_mode=model_mode,
         use_ragged_attention=self.use_ragged_attention,

@@ -2059,6 +2079,7 @@
         key=key,
         value=value,
         decoder_segment_ids=decoder_segment_ids,
+        segment_positions=inputs_positions,
         lengths=lengths,
         model_mode=model_mode,
         use_ragged_attention=self.use_ragged_attention,
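For intuition about what segment_positions carries in the packed (THD) path: each row concatenates multiple sequences, segment ids mark which tokens belong together, and positions restart at 0 per sequence. A minimal pure-JAX illustration (values are made up; only the SequenceDescriptor call is from this commit):

import jax.numpy as jnp

# Two sequences of lengths 3 and 5 packed into one row of length 8.
segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 2, 2]], dtype=jnp.int32)
segment_positions = jnp.array([[0, 1, 2, 0, 1, 2, 3, 4]], dtype=jnp.int32)

# With packing enabled, cudnn_flash_attention now forwards both tensors to TE
# via SequenceDescriptor.from_segment_ids_and_pos(segment_ids=segment_ids,
# segment_pos=segment_positions) so the fused kernel confines attention to
# each packed segment.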

src/maxtext/layers/attentions.py

Lines changed: 1 addition & 0 deletions
@@ -1184,6 +1184,7 @@ def __call__(
         key,
         value,
         decoder_segment_ids,
+        inputs_positions,
         model_mode,
         cached_values,
         previous_chunk,

src/maxtext/models/gpt3.py

Lines changed: 1 addition & 1 deletion
@@ -328,7 +328,7 @@ def __call__(
     value = nn.with_logical_constraint(value, self.value_axis_names)
     value = checkpoint_name(value, "value_proj")

-    out = self.attention_op(query, key, value, decoder_segment_ids, model_mode)
+    out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode)

     out = nn.with_logical_constraint(out, self.out_axis_names)

src/maxtext/utils/max_utils.py

Lines changed: 78 additions & 19 deletions
@@ -887,27 +887,86 @@ def reorder_sequence(tensor, cp_size: int, seq_dim: int = 1, to_contiguous: bool
   return reordered.reshape(ori_tensor_shape)


-@partial(jax.jit, static_argnums=1)
-def reorder_causal_load_balanced(batch, cp_size):
-  """Reorders the example batch sequences"""
-  return {
-      key: reorder_sequence(
-          value,  # Pass each key's value inside batch separately
-          cp_size=cp_size,
-      )
-      if key
-      in [
-          "inputs",
-          "targets",
-          "inputs_position",
-          "targets_position",
-          "inputs_segmentation",
-          "targets_segmentation",
-      ]
-      else value
-      for key, value in batch.items()
+@partial(jax.jit, static_argnums=(1, 2, 3))
+def reorder_causal_load_balanced(batch, cp_size, reorder_strategy, hardware="tpu"):
+  """Reorders the example batch sequences using a hardware-appropriate backend.
+
+  On GPU (hardware="gpu" or "gpu_multiprocess"), uses Transformer Engine's
+  reorder_causal_load_balancing which supports both DUAL_CHUNK_SWAP and STRIPED strategies.
+  On TPU/CPU, falls back to the pure-JAX reorder_sequence (DUAL_CHUNK_SWAP only).
+
+  Args:
+    batch: The batch to reorder.
+    cp_size: The size of the compute parallelism.
+    reorder_strategy: The ReorderStrategy enum value (DUAL_CHUNK_SWAP or STRIPED).
+    hardware: The hardware type string ("tpu", "gpu", "gpu_multiprocess", "cpu").
+
+  Returns:
+    The reordered batch.
+
+  Reorder Strategy:
+    - DUAL_CHUNK_SWAP: This strategy splits each query into two chunks and do the mirror swap between
+      GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the
+      multiple of 2 * cp_size.
+      Examples:
+        - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15];
+        - After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]
+
+    - STRIPED: This strategy distributes the tokens in a striped (interleaved) manner across
+      the sequence. This is currently used for THD load balance.
+      Example: Consider 4 GPUs with seqlens=16.
+        - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15]
+        - After reorder: GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15]
+
+  See: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/attention.py
+  """
+  # pylint: disable=import-outside-toplevel
+  from maxtext.common.common_types import ReorderStrategy
+
+  _reorder_keys = {
+      "inputs",
+      "targets",
+      "inputs_position",
+      "targets_position",
+      "inputs_segmentation",
+      "targets_segmentation",
   }

+  if hardware in ("gpu", "gpu_multiprocess"):
+    from transformer_engine.jax.attention import ReorderStrategy as TE_ReorderStrategy
+    from transformer_engine.jax.attention import reorder_causal_load_balancing
+
+    reorder_strategy_map = {
+        ReorderStrategy.DUAL_CHUNK_SWAP: TE_ReorderStrategy.DualChunkSwap,
+        ReorderStrategy.STRIPED: TE_ReorderStrategy.Striped,
+    }
+
+    return {
+        key: reorder_causal_load_balancing(
+            value,
+            reorder_strategy_map[reorder_strategy],
+            cp_size=cp_size,
+            seq_dim=1,
+        )
+        if key in _reorder_keys
+        else value
+        for key, value in batch.items()
+    }
+  else:
+    if reorder_strategy == ReorderStrategy.STRIPED:
+      raise ValueError(
+          f"STRIPED reorder strategy requires Transformer Engine and is only supported on GPU, got hardware={hardware!r}."
+      )
+    return {
+        key: reorder_sequence(
+            value,
+            cp_size=cp_size,
+        )
+        if key in _reorder_keys
+        else value
+        for key, value in batch.items()
+    }


 @staticmethod
 def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int):
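The two layouts in the docstring can be reproduced in a few lines of NumPy. This is a hedged restatement of the examples above, not the implementation used by reorder_sequence or TE:

import numpy as np

def dual_chunk_swap(tokens: np.ndarray, cp_size: int) -> np.ndarray:
  # Split into 2 * cp_size chunks; rank r keeps chunk 2r plus the mirror
  # chunk (2 * cp_size - 1 - 2r), balancing early and late tokens.
  chunks = np.split(tokens, 2 * cp_size)
  return np.concatenate([np.concatenate([chunks[2 * r], chunks[2 * cp_size - 1 - 2 * r]]) for r in range(cp_size)])

def striped(tokens: np.ndarray, cp_size: int) -> np.ndarray:
  # Rank r takes every cp_size-th token starting at offset r.
  return np.concatenate([tokens[r::cp_size] for r in range(cp_size)])

print(dual_chunk_swap(np.arange(16), 4))  # [ 0  1 14 15  4  5 10 11  8  9  6  7 12 13  2  3]
print(striped(np.arange(16), 4))          # [ 0  4  8 12  1  5  9 13  2  6 10 14  3  7 11 15]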

src/maxtext/utils/maxtext_utils.py

Lines changed: 20 additions & 6 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.

 # pylint: disable=line-too-long, disable=bare-except, consider-using-generator
-""" Utils that are only interesting to MaxText. """
+"""Utils that are only interesting to MaxText."""

 import functools
 import pickle

@@ -39,7 +39,13 @@
 import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager

 from maxtext.configs import pyconfig
-from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode
+from maxtext.common.common_types import (
+    DecoderBlockType,
+    MODEL_MODE_PREFILL,
+    MODEL_MODE_AUTOREGRESSIVE,
+    ReorderStrategy,
+    ShardMode,
+)
 from maxtext.configs import types
 from maxtext.inference.page_manager import PageState
 from maxtext.common import checkpointing

@@ -113,19 +119,27 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar
   return functional_eval, in_shardings, out_shardings, static_argnums, donate_argnums


-def shard_reorder_causal_load_balanced(batch, cp_size, shard_mode):
+def shard_reorder_causal_load_balanced(
+    batch, cp_size, shard_mode, reorder_strategy=ReorderStrategy.DUAL_CHUNK_SWAP, hardware="tpu"
+):
   """Shard the output of the reordered sequence."""
-  reordered = max_utils.reorder_causal_load_balanced(batch, cp_size)
+  reordered = max_utils.reorder_causal_load_balanced(batch, cp_size, reorder_strategy, hardware)
   for _, v in batch.items():
     if isinstance(v, jax.Array):
       reordered = sharding.maybe_shard_with_name(reordered, v.sharding, shard_mode)
       break
   return reordered


-def get_reorder_callable(cp_size, shard_mode):
+def get_reorder_callable(cp_size, shard_mode, reorder_strategy=ReorderStrategy.DUAL_CHUNK_SWAP, hardware="tpu"):
   """Creates a callable that can be used with map() to reorder batches."""
-  return functools.partial(shard_reorder_causal_load_balanced, cp_size=cp_size, shard_mode=shard_mode)
+  return functools.partial(
+      shard_reorder_causal_load_balanced,
+      cp_size=cp_size,
+      shard_mode=shard_mode,
+      reorder_strategy=reorder_strategy,
+      hardware=hardware,
+  )


 def get_shaped_batch(config):