Skip to content

Commit fe1b450

Browse files
committed
Update tests
1 parent a5727a7 commit fe1b450

2 files changed

Lines changed: 31 additions & 1 deletion

File tree

model2vec/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def __init__(
4949
self.unk_token_id = None
5050

5151
self.config = config
52-
self.normalize = config.get("normalize", normalize if normalize is not None else False)
52+
if normalize is not None:
53+
self.normalize = normalize
54+
else:
55+
self.normalize = config.get("normalize", False)
5356

5457
@property
5558
def normalize(self) -> bool:

tests/test_model.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,30 @@ def test_load_pretrained(
9696
np.testing.assert_array_equal(loaded_model.embedding.weight.numpy(), mock_vectors)
9797
assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
9898
assert loaded_model.config == mock_config
99+
100+
101+
def test_initialize_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
102+
"""Tests whether the normalization initialization is correct."""
103+
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=None)
104+
assert not model.normalize
105+
106+
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=False)
107+
assert not model.normalize
108+
109+
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=True)
110+
assert model.normalize
111+
112+
model = StaticModel(mock_vectors, mock_tokenizer, {"normalize": False}, normalize=True)
113+
assert model.normalize
114+
115+
model = StaticModel(mock_vectors, mock_tokenizer, {"normalize": True}, normalize=False)
116+
assert not model.normalize
117+
118+
119+
def test_set_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
120+
"""Tests whether the normalize is set correctly."""
121+
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=True)
122+
model.normalize = False
123+
assert model.config == {"normalize": False}
124+
model.normalize = True
125+
assert model.config == {"normalize": True}

0 commit comments

Comments
 (0)