Commit e770830

feat: support scaled dot-product attention
1 parent 32beb26 commit e770830
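The new module computes attention as softmax(Q Kᵀ / sqrt(d_k)) V: scores are the scaled dot products of queries and keys, an optional boolean mask sets disallowed positions to -inf before the softmax, and the resulting attention weights are applied to V.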

2 files changed

Lines changed: 23 additions & 1 deletion

cs336_basics/scaled_dot_product_attention.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import torch
+from jaxtyping import Bool, Float
+from cs336_basics.softmax import softmax
+
+def scaled_dot_product_attention(
+    Q: Float[torch.Tensor, "... queries d_k"],
+    K: Float[torch.Tensor, "... keys d_k"],
+    V: Float[torch.Tensor, "... values d_v"],
+    mask: Bool[torch.Tensor, "... queries keys"] | None = None,
+) -> Float[torch.Tensor, "... queries d_v"]:
+    d_k = Q.shape[-1]
+    sqrt_d_k: float = d_k ** 0.5
+    scores = Q @ K.transpose(-2, -1) / sqrt_d_k  # (..., queries, keys)
+
+    if mask is not None:
+        scores = scores.masked_fill(~mask, float("-inf"))
+
+    attention = softmax(scores, dim=-1)  # (..., queries, keys)
+    # (..., queries, keys) @ (..., values, d_v) -> (..., queries, d_v)
+    result = attention @ V
+    return result
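For reference, a minimal usage sketch (the shapes and the causal mask are illustrative only; it assumes the boolean convention above, where True marks positions a query may attend to):

import torch
from cs336_basics.scaled_dot_product_attention import scaled_dot_product_attention

batch, queries, keys, d_k, d_v = 2, 4, 4, 8, 16
Q = torch.randn(batch, queries, d_k)
K = torch.randn(batch, keys, d_k)
V = torch.randn(batch, keys, d_v)  # one value row per key

# Causal mask: query i may attend to keys 0..i; broadcasts over the batch.
mask = torch.tril(torch.ones(queries, keys, dtype=torch.bool))

out = scaled_dot_product_attention(Q, K, V, mask)
print(out.shape)  # torch.Size([2, 4, 16])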

tests/adapters.py

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@
 from cs336_basics.swiglu import SwiGLU
 from cs336_basics.rope import RoPE
 from cs336_basics.softmax import softmax
+from cs336_basics.scaled_dot_product_attention import scaled_dot_product_attention
 
 def run_linear(
     d_in: int,
@@ -116,7 +117,7 @@ def run_scaled_dot_product_attention(
     Returns:
         Float[Tensor, " ... queries d_v"]: Output of SDPA
     """
-    raise NotImplementedError
+    return scaled_dot_product_attention(Q, K, V, mask)
 
 
 def run_multihead_self_attention(
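A quick way to sanity-check the adapter is to compare it against PyTorch's built-in SDPA, which uses the same boolean-mask convention (a sketch, assuming PyTorch >= 2.0 and that cs336_basics.softmax matches torch.softmax up to floating-point tolerance):

import torch
import torch.nn.functional as F
from cs336_basics.scaled_dot_product_attention import scaled_dot_product_attention

Q, K, V = (torch.randn(2, 4, 8) for _ in range(3))
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # causal

expected = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
actual = scaled_dot_product_attention(Q, K, V, mask)
torch.testing.assert_close(actual, expected)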
