@@ -359,6 +359,7 @@ def _splash_attention_forward(
359359 kv_seq_len : int | None = None ,
360360 use_base2_exp : bool = True ,
361361 use_experimental_scheduler : bool = False ,
362+ vmem_limit_bytes : int | None = None ,
362363):
363364 num_q_heads , padded_q_seq_len , head_dim_qk = q .shape
364365 head_dim_v = v .shape [- 1 ]
@@ -429,6 +430,7 @@ def v_index_map(h, i, j, *_):
429430 flags = {"XLA_TPU_FORCE_LP_LLO_SCHEDULER" : use_experimental_scheduler },
430431 disable_bounds_checks = True ,
431432 skip_device_barrier = True ,
433+ vmem_limit_bytes = vmem_limit_bytes ,
432434 ),
433435 out_shape = out_shapes ,
434436 )(q , k , v )
@@ -446,6 +448,7 @@ def _splash_attention_forward_mhpt(
446448 kv_seq_len : int | None = None ,
447449 use_base2_exp : bool = True ,
448450 use_experimental_scheduler : bool = False ,
451+ vmem_limit_bytes : int | None = None ,
449452):
450453 num_q_heads , padded_q_seq_len , head_dim_qk = q .shape
451454 head_dim_v = v .shape [- 1 ]
@@ -518,6 +521,7 @@ def out_index_map(h, i, j, *_):
518521 flags = {"XLA_TPU_FORCE_LP_LLO_SCHEDULER" : use_experimental_scheduler },
519522 disable_bounds_checks = True ,
520523 skip_device_barrier = True ,
524+ vmem_limit_bytes = vmem_limit_bytes ,
521525 ),
522526 out_shape = out_shapes ,
523527 )(q , k , v )
@@ -532,6 +536,7 @@ def make_splash_mha(
532536 heads_per_tile : int = 1 ,
533537 use_base2_exp : bool = True ,
534538 use_experimental_scheduler : bool = False ,
539+ vmem_limit_bytes : int | None = None ,
535540):
536541 def _splash_attention (q , k , v ):
537542 if heads_per_tile > 1 :
@@ -546,6 +551,7 @@ def _splash_attention(q, k, v):
546551 kv_seq_len = orig_kv_seq_len ,
547552 use_base2_exp = use_base2_exp ,
548553 use_experimental_scheduler = use_experimental_scheduler ,
554+ vmem_limit_bytes = vmem_limit_bytes ,
549555 )
550556 return _splash_attention_forward (
551557 q ,
@@ -557,6 +563,7 @@ def _splash_attention(q, k, v):
557563 kv_seq_len = orig_kv_seq_len ,
558564 use_base2_exp = use_base2_exp ,
559565 use_experimental_scheduler = use_experimental_scheduler ,
566+ vmem_limit_bytes = vmem_limit_bytes ,
560567 )
561568
562569 return _splash_attention
@@ -581,6 +588,7 @@ def tpu_custom_attention(
581588 heads_per_tile = None ,
582589 use_base2_exp = True ,
583590 use_experimental_scheduler = False ,
591+ vmem_limit_bytes = None ,
584592 flash_block_sizes = None ,
585593):
586594 _LOG2_E = 1.44269504
@@ -592,6 +600,7 @@ def tpu_custom_attention(
592600 block_kv_compute = flash_block_sizes .get ("block_kv_compute" , block_kv_compute )
593601 block_kv_compute_in = flash_block_sizes .get ("block_kv_compute_in" , block_kv_compute_in )
594602 heads_per_tile = flash_block_sizes .get ("heads_per_tile" , heads_per_tile )
603+ vmem_limit_bytes = flash_block_sizes .get ("vmem_limit_bytes" , vmem_limit_bytes )
595604
596605 block_q = block_q if block_q is not None else DEFAULT_BQSIZE
597606 block_kv = block_kv if block_kv is not None else DEFAULT_BKVSIZE
@@ -639,6 +648,7 @@ def _kernel_3d(q_3d, k_3d, v_3d):
639648 heads_per_tile = heads_per_tile ,
640649 use_base2_exp = use_base2_exp ,
641650 use_experimental_scheduler = use_experimental_scheduler ,
651+ vmem_limit_bytes = vmem_limit_bytes ,
642652 )
643653 out = splash_kernel (
644654 q_3d_padded .astype (jnp .bfloat16 ),
@@ -706,6 +716,7 @@ def make_custom_splash_sdpa(mesh, env, **kwargs):
706716 use_k_smooth = kwargs .get ("use_k_smooth" , True )
707717 use_base2_exp = kwargs .get ("use_base2_exp" , True )
708718 use_experimental_scheduler = kwargs .get ("use_experimental_scheduler" , False )
719+ vmem_limit_bytes = kwargs .get ("vmem_limit_bytes" , None )
709720
710721 def _simple_attention (q , k , v , scale = None ):
711722 s = scale if scale is not None else 1.0 / math .sqrt (q .shape [- 1 ])
@@ -747,6 +758,7 @@ def _sdpa(
747758 heads_per_tile = hpt ,
748759 use_base2_exp = use_base2_exp ,
749760 use_experimental_scheduler = use_experimental_scheduler ,
761+ vmem_limit_bytes = vmem_limit_bytes ,
750762 flash_block_sizes = flash_block_sizes ,
751763 )
752764 return env .j2t_iso (result )
0 commit comments