Commit e887876

Update kernel memory

1 parent ec7f1c4

2 files changed: 12 additions & 10 deletions

src/maxtext/inference/paged_attention_kernel_v2.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -622,8 +622,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
   )
   in_specs = [
       q_block_spec,
-      pl.BlockSpec(memory_space=pl.MemorySpace.ANY),
-      pl.BlockSpec(memory_space=pl.MemorySpace.ANY),
+      pl.BlockSpec(memory_space=None),
+      pl.BlockSpec(memory_space=None),
   ]
   out_specs = q_block_spec
   lm_scratch = pltpu.VMEM(
```
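For context, the specs being rewritten here are the unblocked kind: a pl.BlockSpec that carries only a memory-space constraint and no block shape, so Pallas hands the kernel the full array and the kernel slices it manually (the comment in attention_mla.py below says exactly this). Below is a minimal sketch of that pattern using the explicit pltpu.ANY constraint this commit replaces; the names and shapes (copy_kernel, blocked_copy, BLK, D) are hypothetical and not part of MaxText.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

BLK, N_BLKS, D = 128, 8, 128  # hypothetical tile sizes for illustration

def copy_kernel(x_hbm_ref, o_ref, scratch_ref, sem):
  # x_hbm_ref is the *full* (N_BLKS * BLK, D) array: with a memory-space-only
  # BlockSpec there is no automatic blocking, so the kernel picks its own
  # slice and DMAs it into VMEM scratch before touching the data.
  i = pl.program_id(0)
  copy = pltpu.make_async_copy(
      x_hbm_ref.at[pl.ds(i * BLK, BLK), :], scratch_ref, sem
  )
  copy.start()
  copy.wait()
  o_ref[...] = scratch_ref[...]

def blocked_copy(x):
  return pl.pallas_call(
      copy_kernel,
      grid=(N_BLKS,),
      # Input: memory-space constraint only, no block shape -> passed whole.
      in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)],
      # Output: a normal blocked spec; Pallas delivers one (BLK, D) tile
      # per grid step.
      out_specs=pl.BlockSpec((BLK, D), lambda i: (i, 0)),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      scratch_shapes=[
          pltpu.VMEM((BLK, D), jnp.float32),
          pltpu.SemaphoreType.DMA,
      ],
  )(x)

x = jnp.arange(N_BLKS * BLK * D, dtype=jnp.float32).reshape(N_BLKS * BLK, D)
assert jnp.array_equal(blocked_copy(x), x)  # runs on TPU (or with interpret=True)
```

The commit swaps the explicit ANY constraint for memory_space=None in each of these specs; the pass-the-full-array contract described by the kernels' own comments is what the sketch illustrates.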

src/maxtext/layers/attention_mla.py

Lines changed: 10 additions & 8 deletions
```diff
@@ -345,8 +345,8 @@ def backward_computation(q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, d_score
 
   q_spec = pl.BlockSpec((None, bT, H, D_padded), lambda b, t: (b, t, 0, 0))
   w_spec = pl.BlockSpec((None, bT, H_padded_w), lambda b, t: (b, t, 0))
-  k_spec_any = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
-  d_score_spec_any = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+  k_spec_any = pl.BlockSpec(memory_space=None)
+  d_score_spec_any = pl.BlockSpec(memory_space=None)
 
   d_q_spec = pl.BlockSpec((None, bT, H, D_padded), lambda b, t: (b, t, 0, 0))
   d_w_spec = pl.BlockSpec((None, bT, H_padded_w), lambda b, t: (b, t, 0))
@@ -376,8 +376,8 @@ def backward_computation(q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, d_score
   grid_k = (B, S_padded // bS)
 
   k_spec = pl.BlockSpec((None, bS, D_padded), lambda b, s: (b, s, 0))
-  q_spec_any = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
-  w_spec_any = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+  q_spec_any = pl.BlockSpec(memory_space=None)
+  w_spec_any = pl.BlockSpec(memory_space=None)
   # d_score_spec_any reused
 
   d_k_spec = pl.BlockSpec((None, bS, D_padded), lambda b, s: (b, s, 0))
@@ -703,22 +703,22 @@ def _computation_impl(self, q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, mask
 
   # k: (B, S_padded, D_padded) -> Full array in HBM
   # We use ANY memory space, so we must pass the full array and slice manually in the kernel
-  k_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+  k_spec = pl.BlockSpec(memory_space=None)
 
   # mask
   has_mask = mask is not None
   if has_mask:
     # mask: (B, T, S) -> Full array in HBM
-    mask_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+    mask_spec = pl.BlockSpec(memory_space=None)
   else:
     # Dummy mask to satisfy Pallas signature
     # Create a small dummy mask
     dummy_mask = jnp.zeros((1, 1), dtype=jnp.float32)
-    mask_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+    mask_spec = pl.BlockSpec(memory_space=None)
 
   # Outputs
   # o_score: (B, T, S) -> Full array in HBM
-  o_score_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+  o_score_spec = pl.BlockSpec(memory_space=None)
 
   out_shape = jax.ShapeDtypeStruct((B, T_padded, S_padded), dtype=jnp.float32)
 
@@ -854,10 +854,12 @@ def __call__(
 
   if True:
     # early return
+    print("use kernel implementation")
     weights = self.weights_proj(inputs_q)
     weights = weights * (self.n_heads**-0.5) * self.softmax_scale
     return self.computation(q, k, weights, attention_mask, self.config.index_topk)
 
+  print("use JAX implementation")
   # Compute Index Scores
   # QK product: relu(q @ k.T), [b, t, s, h]
   # Similar to MQA, each key is shared by h query head
```
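The else branch in the _computation_impl hunk above shows a common Pallas idiom: pallas_call takes a fixed positional argument list, so when the optional mask is absent, a tiny placeholder is passed instead of changing the kernel signature. A minimal sketch of that idiom under assumed names (add_kernel, maybe_masked_add; not MaxText APIs):

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, mask_ref, o_ref, *, has_mask):
  # has_mask is a Python bool bound at trace time, so when it is False
  # the dummy mask_ref is never read and its shape never matters.
  out = x_ref[...] + 1.0
  if has_mask:
    out = out + mask_ref[...]
  o_ref[...] = out

def maybe_masked_add(x, mask=None):
  has_mask = mask is not None
  if not has_mask:
    # Dummy mask to satisfy the fixed pallas_call signature.
    mask = jnp.zeros((1, 1), dtype=jnp.float32)
  return pl.pallas_call(
      functools.partial(add_kernel, has_mask=has_mask),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, mask)

x = jnp.ones((8, 128), dtype=jnp.float32)
unmasked = maybe_masked_add(x)              # dummy-mask path
masked = maybe_masked_add(x, mask=0.5 * x)  # real-mask path
```

Because the branch on has_mask resolves during tracing, the compiled kernel for the dummy path contains no mask read at all.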
