@@ -306,7 +306,7 @@ def pallas_attention(
306306 head_dim = hl .specialize (q_in .size (- 1 ))
307307 assert head_dim == k_in .size (- 1 ) == v_in .size (- 1 )
308308 q_view = q_in .reshape ([- 1 , m_dim , head_dim ])
309- k_view = k_in .reshape ([- 1 , n_dim , head_dim ]). transpose ( 1 , 2 )
309+ k_view = k_in .reshape ([- 1 , n_dim , head_dim ])
310310 v_view = v_in .reshape ([- 1 , n_dim , head_dim ])
311311 out = torch .empty_like (q_view )
312312 sm_scale = 1.0 / math .sqrt (head_dim )
@@ -319,9 +319,10 @@ def pallas_attention(
319319 for tile_n in hl .tile (v_view .size (1 )):
320320 # scaling Q in-loop on-demand reduces spillage, faster than keeping pre-scaled Q
321321 q_scaled = q * qk_scale
322- k = k_view [tile_b , :, tile_n ]
322+ k = k_view [tile_b , tile_n , :]
323+ # Keep scores in fp32 to match SDPA tolerances on bf16/fp16 inputs.
323324 # same as hl.dot(q, k, out_dtype=torch.float32)
324- qk = torch .bmm (q_scaled , k , torch .float32 )
325+ qk = torch .bmm (q_scaled , k . transpose ( 1 , 2 ) , torch .float32 )
325326 m_ij = torch .maximum (m_i , torch .amax (qk , - 1 ))
326327 qk = qk - m_ij [:, :, None ]
327328 p = torch .exp2 (qk )
@@ -1059,7 +1060,7 @@ def test_attention_fori_loop_correctness_head_dim_256(self) -> None:
10591060 "((4, 128, 128), 'jnp.float32', 'vmem'), "
10601061 "((4, 128, 128), 'jnp.float32', 'vmem'), "
10611062 "((4, 128, 256), 'jnp.float32', 'vmem'), "
1062- "((4, 256, 128 ), 'jnp.float32', 'vmem'), "
1063+ "((4, 128, 256 ), 'jnp.float32', 'vmem'), "
10631064 "((), None, 'dma_semaphore'), "
10641065 "((4, 128, 256), 'jnp.float32', 'vmem'), "
10651066 "((), None, 'dma_semaphore')]" ,
0 commit comments