Commit d1a2b24

Revert "Integrate tokamax ring attention as optional attention kernel for WAN 2.1" (#305)
This reverts commit f68c7b0. Co-authored-by: Elisa Tsai <elisatsai@google.com>
1 parent d69f3c7 · commit d1a2b24

2 files changed: 40 additions & 72 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 36 additions & 68 deletions
@@ -27,7 +27,6 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
-from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging

@@ -305,92 +304,62 @@ def wrap_flash_attention(query, key, value):
           mask=mask,
           q_seq_shards=1, # the sizes of the axis is sharding over seq_len
           config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=True if "ring" in attention_kernel else False,
-      )
-    elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=True,
-          ring_axis="fsdp",
+          save_residuals=True if attention_kernel == "ring" else False,
       )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
           head_shards=1, # the sizes of the axis is sharding over heads
           q_seq_shards=1, # the sizes of the axis is sharding over seq_len
           block_sizes=block_sizes,
-          save_residuals=True if "ring" in attention_kernel else False,
+          save_residuals=True if attention_kernel == "ring" else False,
           residual_checkpoint_name=residual_checkpoint_name
       )
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

-    if attention_kernel == "tokamax_ring":
-      # For tokamax_ring, use the kernel directly without vmap
-      # The ring attention kernel handles the ring topology internally
-      if not mask_padding_tokens:
-        segment_ids = None
-      attention_output = splash_kernel(
-          fwd_mask_info=None,
-          dkv_mask_info=None,
-          q=query,
-          k=key,
-          v=value,
-          segment_ids=segment_ids,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          mask_value=-jnp.inf,
-          mask_function=None,
-          fwd_mask_sparsity=1.0,
-          save_residuals=True,
-      )
+    if not mask_padding_tokens:
+      segment_ids = None
+    if attention_kernel in ["flash", "tokamax_flash"]:
+      attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
-      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-
-      if not mask_padding_tokens:
-        segment_ids = None
-      if attention_kernel in ["flash", "tokamax_flash"]:
-        attention_output = vmapped_splash(query, key, value, segment_ids)
-      else:
-        if num_fsdp_shards > 1:
-          out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-          m = lse.astype(jnp.float32)
-          l = jnp.exp(lse - m)
-          o = out.astype(jnp.float32) * l[..., None]
+      if num_fsdp_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]

-          perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]

-          k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
-          v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)

-          def ring_scan_body(carry, _):
-            m, l, o, k_current, v_current = carry
-            k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
-            v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)

-            out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)

-            m_chunk = lse_chunk.astype(jnp.float32)
-            m_old = m
-            m = jnp.maximum(m_old, m_chunk)
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)

-            exp_m_diff = jnp.exp(m_old - m)
-            exp_m_chunk_diff = jnp.exp(m_chunk - m)
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)

-            l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-            o = o * exp_m_diff[..., None]
-            o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)

-            # Return the updated state for the next iteration
-            return (m, l, o, k_next, v_next), None
+          # Return the updated state for the next iteration
+          return (m, l, o, k_next, v_next), None

-          initial_carry = (m, l, o, k1, v1)
-          (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)

-          attention_output = o_final / l_final[..., None]
-        else:
-          raise ValueError("ring attention requires fsdp > 1")
+        attention_output = o_final / l_final[..., None]
+      else:
+        raise ValueError("ring attention requires fsdp > 1")

     return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

@@ -566,7 +535,7 @@ def _apply_attention(
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
     )
-  elif "ring" in attention_kernel:
+  elif attention_kernel == "ring":
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
         mask_padding_tokens=mask_padding_tokens,
@@ -577,7 +546,6 @@ def _apply_attention(
     raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")


-
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
   """Multi-head dot product attention with a limited number of queries."""
   num_kv, num_heads, k_features = key.shape[-3:]
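
Note on the retained code path: the "ring" branch kept by this revert rotates KV shards around the "fsdp" mesh axis with jax.lax.ppermute and folds each shard's flash-attention output into a running online-softmax state (m, l, o). The snippet below is a minimal single-device sketch of just that combine step, assuming each per-chunk call returns a chunk-normalized output plus its per-query log-sum-exp; chunk_attn and merge are illustrative stand-ins, not functions from this repository.

import jax
import jax.numpy as jnp

def chunk_attn(q, k, v):
  # Toy stand-in for one vmapped_splash call: softmax attention over a single
  # KV chunk, plus this chunk's per-query log-sum-exp (the saved residual).
  s = q @ k.T
  lse = jax.nn.logsumexp(s, axis=-1)
  out = jax.nn.softmax(s, axis=-1) @ v
  return out, lse

def merge(o, l, m, out_chunk, lse_chunk):
  # Online-softmax update, mirroring the body of ring_scan_body in the diff.
  m_new = jnp.maximum(m, lse_chunk)
  exp_old = jnp.exp(m - m_new)
  exp_chunk = jnp.exp(lse_chunk - m_new)
  l_new = l * exp_old + exp_chunk
  o_new = o * exp_old[..., None] + exp_chunk[..., None] * out_chunk
  return o_new, l_new, m_new

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (16, 8))

# Reference: softmax attention over the full KV sequence.
ref = jax.nn.softmax(q @ k.T, axis=-1) @ v

# Same computation split into two KV "shards" and re-combined.
out0, lse0 = chunk_attn(q, k[:8], v[:8])
o, l, m = out0, jnp.ones_like(lse0), lse0  # initial carry, as in the diff
out1, lse1 = chunk_attn(q, k[8:], v[8:])
o, l, m = merge(o, l, m, out1, lse1)

assert jnp.allclose(o / l[..., None], ref, atol=1e-5)

Under that assumption, o / l[..., None] after folding in every shard equals full-sequence softmax attention, which is the invariant the final division in the diff relies on.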

src/maxdiffusion/pyconfig.py

Lines changed: 4 additions & 4 deletions
@@ -195,8 +195,8 @@ def user_init(raw_keys):

   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
   # Verify qkv is sharded across sequence.
-  if "ring" in raw_keys["attention"] or raw_keys["attention_sharding_uniform"]:
-    max_logging.log(f"Adding sequence sharding to q and kv if not already present because '{raw_keys['attention']}' contains 'ring' or {raw_keys['attention_sharding_uniform']} is set.")
+  if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
+    max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.")
     logical_axis_rules = list(raw_keys["logical_axis_rules"])
     max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
     new_rules = []
@@ -206,12 +206,12 @@ def user_init(raw_keys):
       logical_axis_rules.append(q_seq_sharding)
     if kv_seq_sharding not in logical_axis_rules:
       logical_axis_rules.append(kv_seq_sharding)
-    if "ring" in raw_keys["attention"]:
+    if raw_keys["attention"] == "ring":
       for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
         if ring_attention_axis_rule not in logical_axis_rules:
           max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
           new_rules.append(ring_attention_axis_rule)
-    else:  # attention contains 'flash' but sequence parallel sharding requested for both self and cross attention
+    else:  # attention =flash but sequence parallel sharding requested for both self and cross attention
       for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
         if seq_parallel_axis_rule not in logical_axis_rules:
           max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")
