import flax.linen as nn
from flax import nnx
import jax
+from jax.ad_checkpoint import checkpoint_name
from jax.sharding import PartitionSpec
import jax.numpy as jnp
from jax.experimental import shard_map
@@ -187,30 +188,6 @@ def _tpu_flash_attention(
  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
-  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
-  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
-
-  shard_head_size = mesh.shape["tensor"]
-
-  @functools.partial(
-      jax.jit,
-      static_argnames=["multi_head_mask", "shard_head_size"],
-  )
-  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=shard_head_size,  # the size of the axis sharding over heads
-        q_seq_shards=1,  # the size of the axis sharding over seq_len
-        block_sizes=block_sizes,
-    )
-    return splash_kernel
-
-  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-
-  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)

  @functools.partial(
      shard_map.shard_map,
@@ -219,12 +196,21 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
          q_axis_names,
          kv_axis_names,
          kv_axis_names,
-          segment_axis_names_splash_kernel,
      ),
      out_specs=q_axis_names,
      check_rep=False,
  )
-  def wrap_flash_attention(query, key, value, splash_kernel):
+  def wrap_flash_attention(query, key, value):
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # make_splash_mha is called inside shard_map, where the head and seq_len dimensions
+    # are already sharded according to in_specs, so head_shards=1 and q_seq_shards=1.
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,  # number of shards over the heads axis
+        q_seq_shards=1,  # number of shards over the query seq_len axis
+        block_sizes=block_sizes,
+    )
    attention_output = jax.vmap(splash_kernel)(query, key, value)
    return attention_output

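For context, here is a minimal standalone sketch of the pattern this hunk introduces: the splash kernel is constructed inside the shard_map body with head_shards=1 and q_seq_shards=1, then vmapped over the batch dimension. The mesh axis name, tensor shapes, dtype, and the function name sharded_splash_attention below are illustrative assumptions, not values from this repository, and the kernel itself only runs on TPU.

import functools
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)
from jax.sharding import Mesh, PartitionSpec as P

# Assumed toy sizes and a 1-D mesh over all devices (batch sharded over "data").
B, H, S, D = 2 * jax.device_count(), 8, 1024, 128
mesh = Mesh(jax.devices(), ("data",))

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(P("data"), P("data"), P("data")),
    out_specs=P("data"),
    check_rep=False,
)
def sharded_splash_attention(q, k, v):
  # Inside shard_map each device only sees its local block, so the kernel
  # is built unsharded: head_shards=1, q_seq_shards=1.
  mask = splash_attention_mask.FullMask(_shape=(q.shape[2], k.shape[2]))
  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * q.shape[1])
  kernel = splash_attention_kernel.make_splash_mha(
      mask=multi_head_mask,
      head_shards=1,
      q_seq_shards=1,
  )
  # The kernel consumes (heads, seq_len, head_dim) operands; vmap over the local batch.
  return jax.vmap(kernel)(q, k, v)

q = k = v = jnp.ones((B, H, S, D), dtype=jnp.bfloat16)
out = sharded_splash_attention(q, k, v)  # (B, H, S, D)

Building the kernel inside the shard_map body also removes the need to derive a manual_sharding_spec for the kernel and thread it through as an extra argument, which is exactly what this change deletes above.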
@@ -236,7 +222,7 @@ def wrap_flash_attention(query, key, value, splash_kernel):
        "Warning, batch dimension should be shardable among the devices in data and fsdp"
        f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
    )
-  x = wrap_flash_attention(query, key, value, splash_kernel)
+  x = wrap_flash_attention(query, key, value)
  x = x[:, :, :query_seq_len, :kv_size]
  x = _reshape_heads_to_head_dim(x)

@@ -632,7 +618,7 @@ def __init__(
      use_memory_efficient_attention: bool = False,
      split_head_dim: bool = False,
      attention_kernel: str = "flash",
-      flash_min_seq_length: int = 4096,
+      flash_min_seq_length: int = 0,
      flash_block_sizes: BlockSizes = None,
      mesh: jax.sharding.Mesh = None,
      dtype: jnp.dtype = jnp.float32,
@@ -809,12 +795,16 @@ def __call__(
    query_proj = _unflatten_heads(query_proj, self.heads)
    key_proj = _unflatten_heads(key_proj, self.heads)
    value_proj = _unflatten_heads(value_proj, self.heads)
+    # output of _unflatten_heads: (batch, heads, seq_len, head_dim)
    query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

+    query_proj = checkpoint_name(query_proj, "query_proj")
+    key_proj = checkpoint_name(key_proj, "key_proj")
+    value_proj = checkpoint_name(value_proj, "value_proj")
    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

    attn_output = attn_output.astype(dtype=dtype)
-
+    attn_output = checkpoint_name(attn_output, "attn_output")
    hidden_states = self.proj_attn(attn_output)
    return hidden_states
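As a usage note, the checkpoint_name tags added above only take effect when a rematerialization policy refers to them by name. Below is a minimal sketch of that pairing using jax.checkpoint with jax.checkpoint_policies.save_only_these_names; the toy_attention function, weights, and shapes are made up for illustration, while the tag strings match the ones added in this change.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def toy_attention(x, w_q, w_k, w_v, w_o):
  # Tag intermediates so a remat policy can refer to them by name.
  q = checkpoint_name(x @ w_q, "query_proj")
  k = checkpoint_name(x @ w_k, "key_proj")
  v = checkpoint_name(x @ w_v, "value_proj")
  scores = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)
  out = checkpoint_name(scores @ v, "attn_output")
  return (out @ w_o).sum()

# Save only the tagged activations on the forward pass; everything else is
# recomputed during the backward pass.
policy = jax.checkpoint_policies.save_only_these_names(
    "query_proj", "key_proj", "value_proj", "attn_output"
)
toy_attention_remat = jax.checkpoint(toy_attention, policy=policy)

x = jnp.ones((16, 64))
w_q = w_k = w_v = w_o = jnp.ones((64, 64))
grads = jax.grad(toy_attention_remat)(x, w_q, w_k, w_v, w_o)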