@@ -345,8 +345,8 @@ def backward_computation(q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, d_score
 
     q_spec = pl.BlockSpec((None, bT, H, D_padded), lambda b, t: (b, t, 0, 0))
     w_spec = pl.BlockSpec((None, bT, H_padded_w), lambda b, t: (b, t, 0))
-    k_spec_any = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
-    d_score_spec_any = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+    k_spec_any = pl.BlockSpec(memory_space=None)
+    d_score_spec_any = pl.BlockSpec(memory_space=None)
 
     d_q_spec = pl.BlockSpec((None, bT, H, D_padded), lambda b, t: (b, t, 0, 0))
     d_w_spec = pl.BlockSpec((None, bT, H_padded_w), lambda b, t: (b, t, 0))
@@ -376,8 +376,8 @@ def backward_computation(q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, d_score
     grid_k = (B, S_padded // bS)
 
     k_spec = pl.BlockSpec((None, bS, D_padded), lambda b, s: (b, s, 0))
-    q_spec_any = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
-    w_spec_any = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+    q_spec_any = pl.BlockSpec(memory_space=None)
+    w_spec_any = pl.BlockSpec(memory_space=None)
     # d_score_spec_any reused
 
     d_k_spec = pl.BlockSpec((None, bS, D_padded), lambda b, s: (b, s, 0))
@@ -703,22 +703,22 @@ def _computation_impl(self, q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, mask
 
         # k: (B, S_padded, D_padded) -> Full array in HBM
         # We use ANY memory space, so we must pass the full array and slice manually in the kernel
-        k_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+        k_spec = pl.BlockSpec(memory_space=None)
 
         # mask
         has_mask = mask is not None
         if has_mask:
             # mask: (B, T, S) -> Full array in HBM
-            mask_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+            mask_spec = pl.BlockSpec(memory_space=None)
         else:
             # Dummy mask to satisfy Pallas signature
             # Create a small dummy mask
             dummy_mask = jnp.zeros((1, 1), dtype=jnp.float32)
-            mask_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+            mask_spec = pl.BlockSpec(memory_space=None)
 
         # Outputs
         # o_score: (B, T, S) -> Full array in HBM
-        o_score_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+        o_score_spec = pl.BlockSpec(memory_space=None)
 
         out_shape = jax.ShapeDtypeStruct((B, T_padded, S_padded), dtype=jnp.float32)
 
@@ -854,10 +854,12 @@ def __call__(
 
         if True:
             # early return
+            print("use kernel implementation")
             weights = self.weights_proj(inputs_q)
             weights = weights * (self.n_heads ** -0.5) * self.softmax_scale
             return self.computation(q, k, weights, attention_mask, self.config.index_topk)
 
+        print("use JAX implementation")
        # Compute Index Scores
        # QK product: relu(q @ k.T), [b, t, s, h]
        # Similar to MQA, each key is shared by h query head
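
For context, below is a minimal standalone sketch (not part of this commit) of the pattern the in-kernel comment above describes: a BlockSpec that only names a memory space hands the kernel the full array resident in HBM, and the kernel slices out and DMA-copies the tile it needs into VMEM itself. Every name here (_double_kernel, double_blocks, S, D, bS) and the toy computation are illustrative assumptions; it keeps the pltpu.MemorySpace.ANY spelling used on the pre-change lines.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _double_kernel(k_hbm_ref, o_ref, k_vmem, dma_sem):
    # k_hbm_ref is the *full* (S, D) array left in HBM, because its BlockSpec
    # only named a memory space. Copy the tile this grid step needs into VMEM.
    s = pl.program_id(0)
    bS = k_vmem.shape[0]
    copy = pltpu.make_async_copy(
        k_hbm_ref.at[pl.ds(s * bS, bS), :],  # manual slice of the HBM array
        k_vmem,                              # VMEM landing buffer (scratch)
        dma_sem,
    )
    copy.start()
    copy.wait()
    o_ref[...] = k_vmem[...] * 2  # stand-in for the real computation


def double_blocks(k: jnp.ndarray, bS: int = 128) -> jnp.ndarray:
    S, D = k.shape  # assumes S is a multiple of bS
    return pl.pallas_call(
        _double_kernel,
        grid=(S // bS,),
        # Input: no block shape, just a memory space -> kernel sees the whole array.
        in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],
        # Output: ordinary blocked spec, Pallas tiles it per grid step.
        out_specs=pl.BlockSpec((bS, D), lambda s: (s, 0)),
        out_shape=jax.ShapeDtypeStruct((S, D), k.dtype),
        scratch_shapes=[
            pltpu.VMEM((bS, D), k.dtype),
            pltpu.SemaphoreType.DMA,
        ],
    )(k)

Usage would look like double_blocks(jnp.ones((1024, 256), jnp.float32)). The payoff of the pattern is that only one (bS, D) tile occupies VMEM at a time while the full array stays in HBM, which is what the k, mask, and o_score specs in the hunks above rely on.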