Commit 32beb26

feat: support softmax function
1 parent 6ccad79 commit 32beb26

2 files changed

Lines changed: 12 additions & 1 deletion

cs336_basics/softmax.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+import torch
+from jaxtyping import Float
+
+def softmax(tensor: Float[torch.Tensor, "..."], dim: int) -> Float[torch.Tensor, "..."]:
+    max_vals = tensor.max(dim=dim, keepdim=True).values
+    tensor_shifted = tensor - max_vals
+    exp_tensor = torch.exp(tensor_shifted)
+    exp_sum = exp_tensor.sum(dim=dim, keepdim=True)
+    result = exp_tensor / exp_sum
+    return result
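The new function uses the standard max-subtraction trick for numerical stability: softmax is invariant to adding a constant along `dim`, so shifting by the per-`dim` maximum keeps `torch.exp` from overflowing on large logits. A minimal sketch comparing the new function against PyTorch's built-in `torch.softmax` (the input values here are illustrative):

    import torch

    from cs336_basics.softmax import softmax

    # Without the max shift, exp(1000.0) overflows float32 to inf,
    # and inf / inf would produce NaNs after the division.
    logits = torch.tensor([[1000.0, 1001.0, 1002.0]])

    probs = softmax(logits, dim=-1)
    print(probs)                                                  # finite, sums to 1
    print(torch.allclose(probs, torch.softmax(logits, dim=-1)))  # True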

tests/adapters.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 from cs336_basics.rmsnorm import RMSNorm
 from cs336_basics.swiglu import SwiGLU
 from cs336_basics.rope import RoPE
+from cs336_basics.softmax import softmax

 def run_linear(
     d_in: int,
@@ -446,7 +447,7 @@ def run_softmax(in_features: Float[Tensor, " ..."], dim: int) -> Float[Tensor, "
         Float[Tensor, "..."]: Tensor with the same shape as `in_features` with the output of
         softmax normalizing the specified `dim`.
     """
-    raise NotImplementedError
+    return softmax(in_features, dim)


 def run_cross_entropy(inputs: Float[Tensor, " batch_size vocab_size"], targets: Int[Tensor, " batch_size"]) -> Float[Tensor, ""]:
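With the adapter wired up, the course test suite can reach the new module through `run_softmax`. A hypothetical smoke test (the shapes and assertions are illustrative, not from the repo):

    import torch

    from tests.adapters import run_softmax

    x = torch.randn(2, 5)
    probs = run_softmax(x, dim=-1)
    assert probs.shape == x.shape                             # shape is preserved
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2))   # rows sum to 1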
