
Commit 2ddf8ab

Fix transformer sharding and cross-attention flash block sizes
1 parent 384d211 commit 2ddf8ab

4 files changed

Lines changed: 120 additions & 42 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 58 additions & 34 deletions
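The cross-attention block-size logic that used to live inline in _tpu_flash_attention moves into a standalone _select_flash_block_sizes helper, which also reads the sequence axis correctly for both rank-3 and rank-4 inputs. As a rough illustration of what the new rule yields for a cross-attention shape, here is a minimal sketch that calls the helper the way the new test does; the absolute import path is assumed for illustration and is not part of this commit.

import jax.numpy as jnp
from maxdiffusion.models.attention_flax import _select_flash_block_sizes  # assumed import path

# Cross-attention: 1024 query tokens attend to a 257-token context, so q_len != kv_len.
query = jnp.zeros((1, 1024, 256), dtype=jnp.bfloat16)
key = jnp.zeros((1, 257, 256), dtype=jnp.bfloat16)

block_sizes = _select_flash_block_sizes(
    query=query, key=key, flash_block_sizes=None, dtype=jnp.bfloat16, attention_kernel="flash"
)

# KV blocks are capped at the next multiple of 128 above kv_len (384 here) and clamped to kv_len.
assert block_sizes.block_kv == 257
assert block_sizes.block_kv_dq == 384   # min(384, q_len=1024)
assert block_sizes.block_q == 1024      # bfloat16 default query block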
@@ -190,6 +190,49 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
   return tensor, kv_size, seq_len


+def _flash_sequence_length(tensor: Array) -> int:
+  if tensor.ndim == 3:
+    return tensor.shape[1]
+  if tensor.ndim == 4:
+    return tensor.shape[2]
+  raise ValueError(f"Flash attention expects rank-3 or rank-4 inputs, got rank {tensor.ndim}.")
+
+
+def _select_flash_block_sizes(
+    query: Array,
+    key: Array,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype,
+    attention_kernel: str,
+) -> BlockSizes:
+  query_seq_len = _flash_sequence_length(query)
+  key_seq_len = _flash_sequence_length(key)
+
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  if key_seq_len != query_seq_len:
+    kv_max_block_size = ((key_seq_len + 127) // 128) * 128
+  else:
+    kv_max_block_size = q_max_block_size
+
+  # Keep configured block sizes for self-attention, but let
+  # cross-attention derive safe KV-aware sizes when q_len != kv_len.
+  if flash_block_sizes and key_seq_len == query_seq_len:
+    return flash_block_sizes
+
+  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
+  return splash_attention_kernel.BlockSizes(
+      block_q=block_size_q,
+      block_kv_compute=min(kv_max_block_size, key_seq_len),
+      block_kv=min(kv_max_block_size, key_seq_len),
+      block_q_dkv=block_size_q,
+      block_kv_dkv=min(kv_max_block_size, key_seq_len),
+      block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
+      block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
+      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
+  )
+
+
 def convert_to_tokamax_splash_config(
     block_sizes: BlockSizes,
     q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
@@ -244,28 +287,7 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""

-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # This is the case for cross-attn.
-  if key.shape[1] != query.shape[1]:
-    kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-  # ensure that for cross attention we override the block sizes.
-  if flash_block_sizes and key.shape[1] == query.shape[1]:
-    block_sizes = flash_block_sizes
-  else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-    block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
-        block_kv_compute=min(kv_max_block_size, key.shape[2]),
-        block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
-        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
-        block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
-        use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
-    )
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
   num_context_shards = mesh.shape["context"]
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
@@ -717,8 +739,8 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )
     self.act = get_activation(activation_fn)
     self.net_2 = nnx.Linear(
@@ -729,8 +751,8 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

   def __call__(self, hidden_states: Array) -> Array:
@@ -979,7 +1001,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )

@@ -993,7 +1015,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )

@@ -1007,7 +1029,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )

@@ -1021,7 +1043,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("heads",),
+            ("embed",),
         ),
     )

@@ -1333,11 +1355,13 @@ def setup(self):
         precision=self.precision,
     )

+    proj_attn_kernel_axes = ("heads", "embed")
+
     self.proj_attn = nn.Dense(
         self.query_dim,
-        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
         use_bias=True,
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         name="i_proj",
@@ -1346,9 +1370,9 @@ def setup(self):

     self.encoder_proj_attn = nn.Dense(
         self.query_dim,
-        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
         use_bias=True,
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         name="e_proj",
src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 4 additions & 4 deletions
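This file only swaps the bias partition axes so they track each layer's output dimension: flax.nnx.Linear stores its kernel as (in_features, out_features) and its bias as (out_features,), so a QKV projection (embed → heads) gets a ("heads",) bias while the output projection (heads → embed) gets ("embed",). A minimal sketch of that convention, with illustrative sizes that are not taken from the model config:

from flax import nnx

embed, heads = 512, 768  # illustrative widths only

qkv = nnx.Linear(
    embed, heads,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)),
    rngs=nnx.Rngs(0),
)
out = nnx.Linear(
    heads, embed,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",)),
    rngs=nnx.Rngs(1),
)

# The bias always has the shape of the kernel's second (output) axis.
assert qkv.kernel.value.shape == (embed, heads) and qkv.bias.value.shape == (heads,)
assert out.kernel.value.shape == (heads, embed) and out.bias.value.shape == (embed,)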
@@ -359,13 +359,13 @@ def __init__(
     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
     qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
-    # Q, K, V biases: [out_features (embed)]
-    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
+    # Q, K, V biases: [out_features (heads)]
+    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))

     # Out kernel: [in_features (heads), out_features (embed)]
     out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
-    # Out bias: [out_features (heads)]
-    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
+    # Out bias: [out_features (embed)]
+    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))

     # Norm scales
     norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 3 additions & 3 deletions
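Same idea for the WAN feed-forward layers: the partition tuple has to follow the kernel's (in_features, out_features) order, so the up-projection (embed → mlp) is sharded as ("embed", "mlp") with an ("mlp",) bias, and the down-projection (mlp → embed) as ("mlp", "embed"). A short sketch with made-up dimensions:

from flax import nnx

embed, mlp = 1024, 4096  # made-up model and hidden widths

ffn_up = nnx.Linear(
    embed, mlp,
    kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp")),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
    rngs=nnx.Rngs(0),
)
ffn_down = nnx.Linear(
    mlp, embed,
    kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
    rngs=nnx.Rngs(1),
)

# The hidden "mlp" axis is the sharded one on the way up and the contracted one on the way down.
assert ffn_up.kernel.value.shape == (embed, mlp)
assert ffn_down.kernel.value.shape == (mlp, embed)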
@@ -193,11 +193,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
                 "embed",
+                "mlp",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )

   def __call__(self, x: jax.Array) -> jax.Array:
@@ -249,8 +249,8 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "embed",
                 "mlp",
+                "embed",
             ),
         ),
     )
src/maxdiffusion/tests/attention_test.py

Lines changed: 55 additions & 1 deletion
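The expected numbers in the new tests fall straight out of the selection rule: with a 257-token KV sequence the cap is the next multiple of 128, i.e. ((257 + 127) // 128) * 128 = 384, which is then clamped by the actual query and KV lengths. A quick arithmetic check, using nothing beyond the formula in the diff:

key_seq_len, query_seq_len = 257, 1024
kv_max_block_size = ((key_seq_len + 127) // 128) * 128

assert kv_max_block_size == 384
assert min(kv_max_block_size, key_seq_len) == 257     # block_kv, block_kv_compute, block_kv_dkv
assert min(kv_max_block_size, query_seq_len) == 384   # block_kv_dkv_compute, block_kv_dq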
@@ -20,7 +20,8 @@
 import jax
 from jax.sharding import Mesh
 import jax.numpy as jnp
-from ..models.attention_flax import FlaxAttention
+from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from ..models.attention_flax import FlaxAttention, _select_flash_block_sizes
 from .. import max_utils
 from .. import pyconfig

@@ -92,6 +93,59 @@ def test_splash_attention(self):

     assert diff_norm < 1.0

+  def test_cross_attention_overrides_configured_flash_block_sizes(self):
+    query = jnp.zeros((1, 1024, 256), dtype=jnp.bfloat16)
+    key = jnp.zeros((1, 257, 256), dtype=jnp.bfloat16)
+    configured_block_sizes = splash_attention_kernel.BlockSizes(
+        block_q=384,
+        block_kv_compute=192,
+        block_kv=320,
+        block_q_dkv=256,
+        block_kv_dkv=288,
+        block_kv_dkv_compute=160,
+        block_q_dq=128,
+        block_kv_dq=96,
+        use_fused_bwd_kernel=False,
+    )
+
+    block_sizes = _select_flash_block_sizes(
+        query=query,
+        key=key,
+        flash_block_sizes=configured_block_sizes,
+        dtype=jnp.bfloat16,
+        attention_kernel="flash",
+    )
+
+    assert block_sizes.block_q == configured_block_sizes.block_q
+    assert block_sizes.block_q_dkv == configured_block_sizes.block_q
+    assert block_sizes.block_q_dq == configured_block_sizes.block_q
+    assert block_sizes.block_kv_compute == 257
+    assert block_sizes.block_kv == 257
+    assert block_sizes.block_kv_dkv == 257
+    assert block_sizes.block_kv_dkv_compute == 384
+    assert block_sizes.block_kv_dq == 384
+
+  def test_default_flash_block_sizes_use_sequence_axis_for_3d_inputs(self):
+    query = jnp.zeros((1, 128, 4096), dtype=jnp.bfloat16)
+    key = jnp.zeros((1, 257, 4096), dtype=jnp.bfloat16)
+
+    block_sizes = _select_flash_block_sizes(
+        query=query,
+        key=key,
+        flash_block_sizes=None,
+        dtype=jnp.bfloat16,
+        attention_kernel="flash",
+    )
+
+    assert block_sizes.block_q == 1024
+    assert block_sizes.block_kv_compute == 257
+    assert block_sizes.block_kv == 257
+    assert block_sizes.block_q_dkv == 1024
+    assert block_sizes.block_kv_dkv == 257
+    assert block_sizes.block_kv_dkv_compute == 128
+    assert block_sizes.block_q_dq == 1024
+    assert block_sizes.block_kv_dq == 128
+

 if __name__ == "__main__":
   absltest.main()
0 commit comments
