Skip to content

Commit 00911fa

Browse files
committed
feat: support cross_entropy
1 parent c797834 commit 00911fa

2 files changed

Lines changed: 21 additions & 1 deletion

File tree

cs336_basics/cross_entropy.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from jaxtyping import Float, Int
2+
from torch import Tensor, exp, log
3+
4+
def cross_entropy(inputs: Float[Tensor, "*batch vocab_size"], targets: Int[Tensor, "*batch"]) -> Float[Tensor, ""]:
    """Return the mean cross-entropy loss of logits against target class indices.

    The class dimension is always the last axis of ``inputs``; any number of
    leading batch-like dimensions is accepted (the common 2-D
    ``(batch_size, vocab_size)`` case behaves exactly as before).

    Args:
        inputs: Unnormalized logits; the last axis indexes the classes.
        targets: Integer class indices with the same shape as ``inputs``
            minus its last axis. Each value must lie in ``[0, vocab_size)``.

    Returns:
        A zero-dimensional tensor holding the loss averaged over every example.
    """
    # Shifting by the per-example maximum does not change the loss (the shift
    # cancels between the two terms below) but keeps exp() from overflowing.
    max_logits = inputs.max(dim=-1, keepdim=True).values
    shifted = inputs - max_logits

    # Pick out the (shifted) logit of the correct class for each example.
    target_logits = shifted.gather(dim=-1, index=targets.unsqueeze(dim=-1))

    # The largest shifted logit is 0, so the sum is at least 1 and log() is safe.
    log_sum_exp = log(exp(shifted).sum(dim=-1, keepdim=True))

    # -log softmax(inputs)[target] == logsumexp(inputs) - inputs[target]
    sample_losses = log_sum_exp - target_logits
    return sample_losses.mean()
19+

tests/adapters.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from cs336_basics.multihead_self_attention import MultiHeadSelfAttention, MultiHeadSelfAttentionWithRoPE
2222
from cs336_basics.transformer_block import TransformerBlock
2323
from cs336_basics.transformer import Transformer
24+
from cs336_basics.cross_entropy import cross_entropy
2425

2526
def run_linear(
2627
d_in: int,
@@ -511,7 +512,7 @@ def run_cross_entropy(inputs: Float[Tensor, " batch_size vocab_size"], targets:
511512
Returns:
512513
Float[Tensor, ""]: The average cross-entropy loss across examples.
513514
"""
514-
raise NotImplementedError
515+
return cross_entropy(inputs, targets)
515516

516517

517518
def run_gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float) -> None:

0 commit comments

Comments
 (0)