Skip to content

Commit 7c006da

Browse files
committed
feat: support silu
1 parent 593f5fc commit 7c006da

3 files changed

Lines changed: 10 additions & 3 deletions

File tree

cs336_basics/silu.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from jaxtyping import Float
2+
from torch import Tensor, exp
3+
4+
5+
def SiLU(in_features: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
    """Apply the SiLU activation to every element of the input.

    Computes ``x * sigmoid(x)`` element-wise.

    Args:
        in_features: Input tensor of arbitrary shape.

    Returns:
        A tensor with the same shape as ``in_features`` holding the
        activated values.
    """
    # Use the built-in sigmoid instead of spelling out 1 / (1 + exp(-x)):
    # the hand-written form materialises exp(-x), which overflows to inf
    # for sufficiently negative inputs.
    return in_features * in_features.sigmoid()

tests/adapters.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from cs336_basics.gradient_clipping import gradient_clipping
2828
from cs336_basics.data_loading import data_loading
2929
from cs336_basics.checkpointing import save_checkpoint, load_checkpoint
30+
from cs336_basics.silu import SiLU
3031

3132
def run_linear(
3233
d_in: int,
@@ -462,7 +463,7 @@ def run_silu(in_features: Float[Tensor, " ..."]) -> Float[Tensor, " ..."]:
462463
Float[Tensor, "..."]: Tensor with the same shape as `in_features`, holding the output of applying
463464
SiLU to each element.
464465
"""
465-
raise NotImplementedError
466+
return SiLU(in_features)
466467

467468

468469
def run_get_batch(

tests/test_train_bpe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _save_tokenizer_artifacts(vocab, merges, output_dir):
123123
right_unicode = ''.join([byte_to_unicode[byte] for byte in right])
124124
merges_file.write(f"['{left_unicode}', '{right_unicode}']\n")
125125

126-
# @pytest.mark.skip()
126+
@pytest.mark.skip()
127127
def test_train_bpe_on_tiny_story_valid():
128128
start_time = time.time()
129129
input_path = data_folder / "TinyStoriesV2-GPT4-valid.txt"
@@ -136,7 +136,7 @@ def test_train_bpe_on_tiny_story_valid():
136136

137137
assert(end_time - start_time <= 120)
138138

139-
# @pytest.mark.skip()
139+
@pytest.mark.skip()
140140
def test_train_bpe_on_tiny_story_train():
141141
start_time = time.time()
142142
input_path = data_folder / "TinyStoriesV2-GPT4-train.txt"

0 commit comments

Comments
 (0)