Skip to content

Commit 5e8334f

Browse files
committed
Fix flash attention shard_map for sequence lengths not divisible by context mesh axis
1 parent 1d5d773 commit 5e8334f

2 files changed

Lines changed: 24 additions & 15 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,27 @@ def _unflatten_heads(tensor, heads):
128128
return tensor
129129

130130

131-
def _reshape_data_for_flash(tensor, heads):
131+
def _reshape_data_for_flash(tensor, heads, num_context_shards = 1):
132132
"""
133133
Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
134134
Unflattens heads if needed, then pads seq_len to a multiple of
135135
num_context_shards so shard_map can split the sequence evenly across the
136136
context mesh axis; returns the tensor and the original (pre-pad) seq_len.
137137
if tensor.ndim != 4:
138138
tensor = _unflatten_heads(tensor, heads)
139-
return tensor
139+
140+
org_seq_len = tensor.shape[2]
141+
142+
# Pad the sequence dimension so it is evenly divisible by the number of shards
143+
# along the context mesh axis, which shard_map requires.
144+
if num_context_shards <= 1:
145+
return tensor, org_seq_len
146+
rem = org_seq_len % num_context_shards
147+
if rem == 0:
148+
return tensor, org_seq_len
149+
pad_width = [(0, 0)] * tensor.ndim
150+
pad_width[2] = (0, num_context_shards - rem)
151+
return jnp.pad(tensor, pad_width), org_seq_len
140152

141153

142154
def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
@@ -145,7 +157,7 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
145157
Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
146158
blocks is divisible by the number of shards.
147159
"""
148-
tensor = _reshape_data_for_flash(tensor, heads)
160+
tensor, _ = _reshape_data_for_flash(tensor, heads)
149161

150162
# Pad head_dim to 128 if less than that.
151163
kv_size = tensor.shape[-1]
@@ -255,9 +267,10 @@ def _tpu_flash_attention(
255267
use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
256268
)
257269
num_context_shards = mesh.shape["context"]
258-
query = _reshape_data_for_flash(query, heads)
259-
key = _reshape_data_for_flash(key, heads)
260-
value = _reshape_data_for_flash(value, heads)
270+
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
271+
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
272+
value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
273+
261274
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
262275
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
263276

@@ -401,6 +414,8 @@ def ring_scan_body(carry, _):
401414
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
402415
)
403416
x = wrap_flash_attention(query, key, value)
417+
# Trim back to original sequence length after context-axis padding.
418+
x = x[:, :, :orig_q_seq_len, :]
404419
x = _reshape_heads_to_head_dim(x)
405420

406421
return x

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,7 @@ def prepare_video_coords(
193193
# pixel_coords[:, 0, ...] selects Frame dimension.
194194
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
195195
frame_coords = pixel_coords[:, 0, ...]
196-
frame_coords = jnp.clip(
197-
frame_coords + self.causal_offset - self.scale_factors[0], min=0
198-
)
196+
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
199197
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)
200198

201199
return pixel_coords
@@ -212,16 +210,12 @@ def prepare_audio_coords(
212210
# 2. Start timestamps
213211
audio_scale_factor = self.scale_factors[0]
214212
grid_start_mel = grid_f * audio_scale_factor
215-
grid_start_mel = jnp.clip(
216-
grid_start_mel + self.causal_offset - audio_scale_factor, min=0
217-
)
213+
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
218214
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
219215

220216
# 3. End timestamps
221217
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
222-
grid_end_mel = jnp.clip(
223-
grid_end_mel + self.causal_offset - audio_scale_factor, min=0
224-
)
218+
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
225219
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
226220

227221
# Stack [num_patches, 2]

0 commit comments

Comments
 (0)