Add packed (THD) ring attention with hardware-aware reorder dispatch
Enable context parallelism (CP) with sequence packing for
context_parallel_strategy="ring" with load balancing. On GPU, this uses
Transformer Engine's striped reorder for THD-packed sequences; on TPU/CPU,
it falls back to the pure-JAX reorder_sequence and never imports TE.
Changes (illustrative sketches of the main pieces follow this list):
- common_types: Add ReorderStrategy enum (AUTO, DUAL_CHUNK_SWAP, STRIPED).
- configs: Add context_parallel_reorder_strategy (default "auto"). Reject
explicit STRIPED on non-GPU at config validation time.
- attention_op: Thread segment_positions through apply_attention,
cudnn_flash_attention, and __call__. Use segment_positions in TE's
SequenceDescriptor for packing. Restrict packing+CP to load-balanced
ring only. Note TE version constraint.
- attentions.py, attention_mla.py, gpt3.py: Pass inputs_positions into
attention_op calls (None for gpt3).
- max_utils: Hardware-dispatched reorder_causal_load_balanced. GPU uses
TE's reorder_causal_load_balancing; TPU/CPU uses reorder_sequence.
TE import is lazy and GPU-only.
- maxtext_utils: Thread reorder_strategy and hardware through
shard_reorder_causal_load_balanced and get_reorder_callable. Default
hardware="tpu" never triggers TE import.
- train_utils: Allow ring+packing; forbid all_gather+packing and
synthetic+packing. Resolve AUTO to STRIPED when packing is enabled,
otherwise DUAL_CHUNK_SWAP. Pass config.hardware to the reorder callable.
Build the data_loader after the reorder wrapper is applied.
- attention_test_util: Pass cfg_cp.hardware so TPU tests use pure-JAX
reorder. Helper is TPU-oriented and does not model GPU packed behavior.
- tests: Add test_gpu_ring_attention_with_packing (sm90+).
Requires TE with reorder_causal_load_balancing; works with TE <=2.11 or
>=2.14 (incompatible with 2.12 and 2.13 due to a known bug).