Skip to content

Commit 6952672

Browse files
committed
Use dtype access instead of hardcoding
1 parent be1f694 commit 6952672

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

ops/ninetoothed/kernels/scaled_dot_product_attention.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def arrange_k_or_v(input):
3030

3131

3232
def application(q, k, v, scale, o):
33-
q_loaded = (q * scale * 1.44269504089).to(ntl.float16)
33+
q_loaded = (q * scale * 1.44269504089).to(q.dtype)
3434

3535
acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
3636
l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)
@@ -44,7 +44,7 @@ def application(q, k, v, scale, o):
4444
l_ij = ntl.sum(p, 1)
4545

4646
alpha = ntl.exp2(m_i - m_ij)
47-
acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i])
47+
acc = acc * alpha[:, None] + ntl.dot(p.to(v.dtype.dtype), v[i])
4848
m_i = m_ij
4949
l_i = l_i * alpha + l_ij
5050

0 commit comments

Comments
 (0)