@@ -128,15 +128,27 @@ def _unflatten_heads(tensor, heads):
128128 return tensor
129129
130130
def _reshape_data_for_flash(tensor, heads, num_context_shards=1):
  """
  Reshapes a tensor for pallas flash attention and pads seq_len for context sharding.

  A tensor that is not already 4D is first passed through
  `_unflatten_heads(tensor, heads)`. Axis 2 of the resulting tensor is treated
  as seq_len and, when `num_context_shards > 1`, is padded at its end up to the
  next multiple of `num_context_shards` so that shard_map can split it evenly
  over the context mesh axis.

  This function does not pad head_dim and does not round seq_len to a multiple
  of flash_block_size.

  Args:
    tensor: Input array; if its ndim is not 4 it is unflattened with `heads`.
    heads: Number of attention heads, forwarded to `_unflatten_heads`.
    num_context_shards: Size of the context mesh axis. Values <= 1 disable
      padding.

  Returns:
    A `(tensor, org_seq_len)` tuple where `tensor` is the (possibly padded) 4D
    array and `org_seq_len` is the size of axis 2 before padding, so callers
    can slice results back to the unpadded length.
  """
  if tensor.ndim != 4:
    tensor = _unflatten_heads(tensor, heads)

  org_seq_len = tensor.shape[2]

  # `-x % n` is the distance from x up to the next multiple of n (0 when x is
  # already aligned), so one expression covers both "no sharding" and
  # "already divisible".
  pad_len = -org_seq_len % num_context_shards if num_context_shards > 1 else 0
  if pad_len == 0:
    return tensor, org_seq_len

  # Pad only the end of the sequence axis; jnp.pad's default mode fills with
  # zeros.
  # NOTE(review): padded positions are plain zeros and are not masked here --
  # confirm the attention kernel ignores padded key/value rows.
  pad_width = [(0, 0)] * tensor.ndim
  pad_width[2] = (0, pad_len)
  return jnp.pad(tensor, pad_width), org_seq_len
140152
141153
142154def _pad_data_for_flash (tensor , heads , flash_block_size , num_shards : int = 1 ):
@@ -145,7 +157,7 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
145157 Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
146158 blocks is divisible by the number of shards.
147159 """
148- tensor = _reshape_data_for_flash (tensor , heads )
160+ tensor , _ = _reshape_data_for_flash (tensor , heads )
149161
150162 # Pad head_dim to 128 if less than that.
151163 kv_size = tensor .shape [- 1 ]
@@ -255,9 +267,10 @@ def _tpu_flash_attention(
255267 use_fused_bwd_kernel = True if attention_kernel == "tokamax_flash" else False ,
256268 )
257269 num_context_shards = mesh .shape ["context" ]
258- query = _reshape_data_for_flash (query , heads )
259- key = _reshape_data_for_flash (key , heads )
260- value = _reshape_data_for_flash (value , heads )
270+ query , orig_q_seq_len = _reshape_data_for_flash (query , heads , num_context_shards )
271+ key , _ = _reshape_data_for_flash (key , heads , num_context_shards )
272+ value , _ = _reshape_data_for_flash (value , heads , num_context_shards )
273+
261274 q_axis_names = nn .logical_to_mesh_axes (axis_names_q )
262275 kv_axis_names = nn .logical_to_mesh_axes (axis_names_kv )
263276
@@ -401,6 +414,8 @@ def ring_scan_body(carry, _):
401414 f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_context: { devices_in_data_context } "
402415 )
403416 x = wrap_flash_attention (query , key , value )
417+ # Trim back to original sequence length after context-axis padding.
418+ x = x [:, :, :orig_q_seq_len , :]
404419 x = _reshape_heads_to_head_dim (x )
405420
406421 return x
0 commit comments