@@ -871,6 +871,7 @@ def apply_attention(
871871 key : Array | KVTensor ,
872872 value : Array | KVTensor ,
873873 decoder_segment_ids : Array | None ,
874+ segment_positions : Array | None ,
874875 lengths : Array | None ,
875876 model_mode : str ,
876877 use_ragged_attention : bool = False ,
@@ -1003,7 +1004,7 @@ def apply_attention(
10031004 Use `dot_product` instead."""
10041005 )
10051006 return (
1006- self .cudnn_flash_attention (query , key , value , decoder_segment_ids , model_mode ),
1007+ self .cudnn_flash_attention (query , key , value , decoder_segment_ids , segment_positions , model_mode ),
10071008 None ,
10081009 None ,
10091010 )
@@ -1513,12 +1514,15 @@ def cudnn_flash_attention(
15131514 key : Array ,
15141515 value : Array ,
15151516 decoder_segment_ids : Array | None ,
1517+ segment_positions : Array | None ,
15161518 model_mode : str = MODEL_MODE_TRAIN ,
15171519 ) -> Array :
15181520 """CUDNN Flash Attention with Transformer Engine.
1519-
1520- 1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism 2.
1521- Context Parallelism currently only supports causal masking and no packing
1521+ 1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism
1522+ 2. Context Parallelism currently only supports causal masking
1523+ 3. Only Ring attention has packing support with striped load balancing
1524+ (context_parallel_strategy="ring" and context_parallel_load_balance=true)
1525+ 4. Breaks with TE 2.12 and 2.13 (known bug); works with TE stable release <=2.11 or >=2.14.
15221526 """
15231527 # These imports are only meant to work in a GPU build.
15241528 # pylint: disable=import-outside-toplevel
@@ -1528,6 +1532,11 @@ def cudnn_flash_attention(
15281532 _ , _ , _ , head_dim = query .shape # pylint: disable=unused-variable
15291533
15301534 using_context_parallelism = self .mesh .shape [self .config .context_sharding ] > 1
1535+ using_load_balanced_ring_cp = (
1536+ using_context_parallelism
1537+ and self .config .context_parallel_strategy == "ring"
1538+ and self .config .context_parallel_load_balance
1539+ )
15311540
15321541 # Initialize default attention configuration
15331542 sliding_window_size = None
@@ -1541,18 +1550,27 @@ def cudnn_flash_attention(
15411550
15421551 # Handle packing configurations
15431552 if self .config .packing and self .config .dataset_type != "synthetic" :
1553+ if using_context_parallelism and not using_load_balanced_ring_cp :
1554+ raise ValueError ("Packing is only supported for load balanced ring attention with context parallelism." )
15441555 qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD'
15451556 if decoder_segment_ids is None :
15461557 decoder_segment_ids = jnp .ones (shape = query .shape [:2 ], dtype = jnp .int32 )
1547- attn_mask = SequenceDescriptor .from_segment_ids_and_pos (segment_ids = decoder_segment_ids , segment_pos = None )
1558+ attn_mask = SequenceDescriptor .from_segment_ids_and_pos (
1559+ segment_ids = decoder_segment_ids , segment_pos = segment_positions
1560+ )
15481561 # Create dummy SequenceDescriptor for lazy_init
15491562 dummy_segment_ids = jnp .ones (shape = query .shape [:2 ], dtype = jnp .int32 )
1550- dummy_attn_mask = SequenceDescriptor .from_segment_ids_and_pos (segment_ids = dummy_segment_ids , segment_pos = None )
1563+ dummy_attn_mask = SequenceDescriptor .from_segment_ids_and_pos (
1564+ segment_ids = dummy_segment_ids , segment_pos = segment_positions
1565+ )
15511566 max_segments_per_seq = self .config .max_segments_per_seq
15521567 elif using_context_parallelism :
15531568 if self .attention_type == AttentionType .LOCAL_SLIDING :
1554- raise AssertionError ("Sliding window attention is not supported for context parallelism" )
1555- # Context parallelism without packing: only supports causal masking
1569+ raise AssertionError (
1570+ "Sliding window attention requires context parallelism with load-balanced ring strategy "
1571+ "and packing enabled."
1572+ )
1573+ # Context parallelism without packing: only supports causal masking, but not sliding window attention
15561574 attn_mask = None
15571575 dummy_attn_mask = None
15581576 mask_type = "causal"
@@ -2003,6 +2021,7 @@ def __call__(
20032021 key ,
20042022 value ,
20052023 decoder_segment_ids ,
2024+ inputs_positions ,
20062025 model_mode ,
20072026 cached_values = None ,
20082027 previous_chunk = None ,
@@ -2034,6 +2053,7 @@ def __call__(
20342053 key = key ,
20352054 value = value ,
20362055 decoder_segment_ids = decoder_segment_ids ,
2056+ segment_positions = inputs_positions ,
20372057 lengths = None ,
20382058 model_mode = model_mode ,
20392059 use_ragged_attention = self .use_ragged_attention ,
@@ -2059,6 +2079,7 @@ def __call__(
20592079 key = key ,
20602080 value = value ,
20612081 decoder_segment_ids = decoder_segment_ids ,
2082+ segment_positions = inputs_positions ,
20622083 lengths = lengths ,
20632084 model_mode = model_mode ,
20642085 use_ragged_attention = self .use_ragged_attention ,