Skip to content

Commit b2d31df

Browse files
Merge pull request #416 from AI-Hypercomputer:plumb_vmem_limit
PiperOrigin-RevId: 930917891
2 parents 4a3ec4f + 77a117e commit b2d31df

3 files changed

Lines changed: 49 additions & 0 deletions

File tree

src/maxdiffusion/kernels/custom_splash_attention.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def _splash_attention_forward(
359359
kv_seq_len: int | None = None,
360360
use_base2_exp: bool = True,
361361
use_experimental_scheduler: bool = False,
362+
vmem_limit_bytes: int | None = None,
362363
):
363364
num_q_heads, padded_q_seq_len, head_dim_qk = q.shape
364365
head_dim_v = v.shape[-1]
@@ -429,6 +430,7 @@ def v_index_map(h, i, j, *_):
429430
flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler},
430431
disable_bounds_checks=True,
431432
skip_device_barrier=True,
433+
vmem_limit_bytes=vmem_limit_bytes,
432434
),
433435
out_shape=out_shapes,
434436
)(q, k, v)
@@ -446,6 +448,7 @@ def _splash_attention_forward_mhpt(
446448
kv_seq_len: int | None = None,
447449
use_base2_exp: bool = True,
448450
use_experimental_scheduler: bool = False,
451+
vmem_limit_bytes: int | None = None,
449452
):
450453
num_q_heads, padded_q_seq_len, head_dim_qk = q.shape
451454
head_dim_v = v.shape[-1]
@@ -518,6 +521,7 @@ def out_index_map(h, i, j, *_):
518521
flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler},
519522
disable_bounds_checks=True,
520523
skip_device_barrier=True,
524+
vmem_limit_bytes=vmem_limit_bytes,
521525
),
522526
out_shape=out_shapes,
523527
)(q, k, v)
@@ -532,6 +536,7 @@ def make_splash_mha(
532536
heads_per_tile: int = 1,
533537
use_base2_exp: bool = True,
534538
use_experimental_scheduler: bool = False,
539+
vmem_limit_bytes: int | None = None,
535540
):
536541
def _splash_attention(q, k, v):
537542
if heads_per_tile > 1:
@@ -546,6 +551,7 @@ def _splash_attention(q, k, v):
546551
kv_seq_len=orig_kv_seq_len,
547552
use_base2_exp=use_base2_exp,
548553
use_experimental_scheduler=use_experimental_scheduler,
554+
vmem_limit_bytes=vmem_limit_bytes,
549555
)
550556
return _splash_attention_forward(
551557
q,
@@ -557,6 +563,7 @@ def _splash_attention(q, k, v):
557563
kv_seq_len=orig_kv_seq_len,
558564
use_base2_exp=use_base2_exp,
559565
use_experimental_scheduler=use_experimental_scheduler,
566+
vmem_limit_bytes=vmem_limit_bytes,
560567
)
561568

562569
return _splash_attention
@@ -581,6 +588,7 @@ def tpu_custom_attention(
581588
heads_per_tile=None,
582589
use_base2_exp=True,
583590
use_experimental_scheduler=False,
591+
vmem_limit_bytes=None,
584592
flash_block_sizes=None,
585593
):
586594
_LOG2_E = 1.44269504
@@ -592,6 +600,7 @@ def tpu_custom_attention(
592600
block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute)
593601
block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in)
594602
heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
603+
vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)
595604

596605
block_q = block_q if block_q is not None else DEFAULT_BQSIZE
597606
block_kv = block_kv if block_kv is not None else DEFAULT_BKVSIZE
@@ -639,6 +648,7 @@ def _kernel_3d(q_3d, k_3d, v_3d):
639648
heads_per_tile=heads_per_tile,
640649
use_base2_exp=use_base2_exp,
641650
use_experimental_scheduler=use_experimental_scheduler,
651+
vmem_limit_bytes=vmem_limit_bytes,
642652
)
643653
out = splash_kernel(
644654
q_3d_padded.astype(jnp.bfloat16),
@@ -706,6 +716,7 @@ def make_custom_splash_sdpa(mesh, env, **kwargs):
706716
use_k_smooth = kwargs.get("use_k_smooth", True)
707717
use_base2_exp = kwargs.get("use_base2_exp", True)
708718
use_experimental_scheduler = kwargs.get("use_experimental_scheduler", False)
719+
vmem_limit_bytes = kwargs.get("vmem_limit_bytes", None)
709720

710721
def _simple_attention(q, k, v, scale=None):
711722
s = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1])
@@ -747,6 +758,7 @@ def _sdpa(
747758
heads_per_tile=hpt,
748759
use_base2_exp=use_base2_exp,
749760
use_experimental_scheduler=use_experimental_scheduler,
761+
vmem_limit_bytes=vmem_limit_bytes,
750762
flash_block_sizes=flash_block_sizes,
751763
)
752764
return env.j2t_iso(result)

src/maxdiffusion/max_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
# pylint: disable=bare-except, consider-using-generator
1919
""" Common Max Utils needed by multiple modules"""
20+
import dataclasses
2021
import functools
2122
from functools import partial, reduce
2223
from contextlib import nullcontext
@@ -612,12 +613,44 @@ def value_or_none(flash_block_sizes, key):
612613
return None
613614

614615

616+
@dataclasses.dataclass(frozen=True)
617+
class CustomFlashBlockSizes:
618+
"""Hashable carrier for the custom splash kernel's block sizes.
619+
620+
The JAX `splash_attention_kernel.BlockSizes` is frozen + slotted and only has
621+
fields for block_q/block_kv/block_kv_compute — it silently drops
622+
block_kv_compute_in, heads_per_tile, and vmem_limit_bytes, which the custom
623+
kernel needs. A plain dict would carry them but is unhashable (it ends up in
624+
nnx's static graphdef, which jit requires to be hashable). This frozen
625+
dataclass is hashable and is read via getattr in wrap_ulysses_attention.
626+
"""
627+
628+
block_q: int | None = None
629+
block_kv: int | None = None
630+
block_kv_compute: int | None = None
631+
block_kv_compute_in: int | None = None
632+
heads_per_tile: int | None = None
633+
vmem_limit_bytes: int | None = None
634+
635+
615636
def get_flash_block_sizes(config):
616637
"""Create custom flash attention BlockSizes."""
617638
flash_block_sizes = None
618639
if len(config.flash_block_sizes.keys()) > 0:
619640
attention_is_tokamax = "tokamax" in config.attention
620641
user_block_sizes: Dict[str, int] = config.flash_block_sizes
642+
# wrap_ulysses_attention reads a non-dict flash_block_sizes via getattr, and
643+
# the custom splash kernel needs fields the JAX BlockSizes dataclass cannot
644+
# hold. Return a frozen, hashable carrier so they survive the trip there.
645+
if "custom" in config.attention:
646+
return CustomFlashBlockSizes(
647+
block_q=user_block_sizes.get("block_q"),
648+
block_kv=user_block_sizes.get("block_kv"),
649+
block_kv_compute=user_block_sizes.get("block_kv_compute"),
650+
block_kv_compute_in=user_block_sizes.get("block_kv_compute_in"),
651+
heads_per_tile=user_block_sizes.get("heads_per_tile"),
652+
vmem_limit_bytes=user_block_sizes.get("vmem_limit_bytes"),
653+
)
621654
if attention_is_tokamax:
622655
max_logging.log(
623656
"Tokamax kernel specified, Note: Tokamax only supports fused backward kernel."

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ def wrap_ulysses_attention(query, key, value):
581581
bkv_compute = 1024
582582
bkv_compute_in = 1024
583583
heads_per_tile = 1
584+
vmem_limit_bytes = None
584585

585586
if flash_block_sizes is not None:
586587
if isinstance(flash_block_sizes, dict):
@@ -589,12 +590,14 @@ def wrap_ulysses_attention(query, key, value):
589590
bkv_compute = flash_block_sizes.get("block_kv_compute", bkv_compute)
590591
bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", bkv_compute_in)
591592
heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
593+
vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)
592594
else:
593595
bq = getattr(flash_block_sizes, "block_q", bq)
594596
bkv = getattr(flash_block_sizes, "block_kv", bkv)
595597
bkv_compute = getattr(flash_block_sizes, "block_kv_compute", bkv_compute)
596598
bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in)
597599
heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile)
600+
vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes)
598601

599602
if use_base2_exp:
600603
query = query * LOG2E
@@ -613,6 +616,7 @@ def wrap_ulysses_attention(query, key, value):
613616
heads_per_tile=heads_per_tile,
614617
use_base2_exp=use_base2_exp,
615618
use_experimental_scheduler=use_experimental_scheduler,
619+
vmem_limit_bytes=vmem_limit_bytes,
616620
)
617621

618622
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))

0 commit comments

Comments
 (0)