@@ -306,7 +306,7 @@ def pallas_attention(
306306 head_dim = hl .specialize (q_in .size (- 1 ))
307307 assert head_dim == k_in .size (- 1 ) == v_in .size (- 1 )
308308 q_view = q_in .reshape ([- 1 , m_dim , head_dim ])
309- k_view = k_in .reshape ([- 1 , n_dim , head_dim ]). transpose ( 1 , 2 )
309+ k_view = k_in .reshape ([- 1 , n_dim , head_dim ])
310310 v_view = v_in .reshape ([- 1 , n_dim , head_dim ])
311311 out = torch .empty_like (q_view )
312312 sm_scale = 1.0 / math .sqrt (head_dim )
@@ -319,9 +319,10 @@ def pallas_attention(
319319 for tile_n in hl .tile (v_view .size (1 )):
320320 # scaling Q in-loop on-demand reduces spillage, faster than keeping pre-scaled Q
321321 q_scaled = q * qk_scale
322- k = k_view [tile_b , :, tile_n ]
322+ k = k_view [tile_b , tile_n , :]
323+ # Keep scores in fp32 to match SDPA tolerances on bf16/fp16 inputs.
323324 # same as hl.dot(q, k, out_dtype=torch.float32)
324- qk = torch .bmm (q_scaled , k , torch .float32 )
325+ qk = torch .bmm (q_scaled , k . transpose ( 1 , 2 ) , torch .float32 )
325326 m_ij = torch .maximum (m_i , torch .amax (qk , - 1 ))
326327 qk = qk - m_ij [:, :, None ]
327328 p = torch .exp2 (qk )
0 commit comments