22
33import numpy as np
44import pytest
5- from transformers import PreTrainedTokenizerFast
5+ from tokenizers import Tokenizer
66
77from model2vec import StaticModel
88
99
10- def test_initialization (
11- mock_vectors : np .ndarray , mock_tokenizer : PreTrainedTokenizerFast , mock_config : dict [str , str ]
12- ) -> None :
10+ def test_initialization (mock_vectors : np .ndarray , mock_tokenizer : Tokenizer , mock_config : dict [str , str ]) -> None :
1311 """Test successful initialization of StaticModel."""
1412 model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config )
15- assert model .vectors .shape == (5 , 2 )
13+ assert model .embedding . weight .shape == (5 , 2 )
1614 assert len (model .tokens ) == 5
1715 assert model .tokenizer == mock_tokenizer
1816 assert model .config == mock_config
1917
2018
21- def test_initialization_token_vector_mismatch (
22- mock_tokenizer : PreTrainedTokenizerFast , mock_config : dict [str , str ]
23- ) -> None :
19+ def test_initialization_token_vector_mismatch (mock_tokenizer : Tokenizer , mock_config : dict [str , str ]) -> None :
2420 """Test if error is raised when number of tokens and vectors don't match."""
2521 mock_vectors = np .array ([[0.1 , 0.2 ], [0.2 , 0.3 ]])
2622 with pytest .raises (ValueError ):
2723 StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config )
2824
2925
3026def test_encode_single_sentence (
31- mock_vectors : np .ndarray , mock_tokenizer : PreTrainedTokenizerFast , mock_config : dict [str , str ]
27+ mock_vectors : np .ndarray , mock_tokenizer : Tokenizer , mock_config : dict [str , str ]
3228) -> None :
3329 """Test encoding of a single sentence."""
3430 model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config )
@@ -37,7 +33,7 @@ def test_encode_single_sentence(
3733
3834
3935def test_encode_multiple_sentences (
40- mock_vectors : np .ndarray , mock_tokenizer : PreTrainedTokenizerFast , mock_config : dict [str , str ]
36+ mock_vectors : np .ndarray , mock_tokenizer : Tokenizer , mock_config : dict [str , str ]
4137) -> None :
4238 """Test encoding of multiple sentences."""
4339 model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config )
@@ -46,24 +42,29 @@ def test_encode_multiple_sentences(
4642
4743
4844def test_encode_empty_sentence (
49- mock_vectors : np .ndarray , mock_tokenizer : PreTrainedTokenizerFast , mock_config : dict [str , str ]
45+ mock_vectors : np .ndarray , mock_tokenizer : Tokenizer , mock_config : dict [str , str ]
5046) -> None :
5147 """Test encoding with an empty sentence."""
5248 model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config )
5349 encoded = model .encode ("" )
5450 assert np .array_equal (encoded , np .zeros ((2 ,)))
5551
5652
57- def test_normalize () -> None :
53+ def test_normalize (mock_vectors : np . ndarray , mock_tokenizer : Tokenizer , mock_config : dict [ str , str ] ) -> None :
5854 """Test normalization of vectors."""
59- X = np .array ([[3 , 4 ], [1 , 2 ], [0 , 0 ]])
60- normalized = StaticModel .normalize (X )
61- expected = np .array ([[0.6 , 0.8 ], [0.4472136 , 0.89442719 ], [0 , 0 ]])
55+ s = "word1 word2 word3"
56+ model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config , normalize = False )
57+ X = model .encode (s )
58+ model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config , normalize = True )
59+ normalized = model .encode (s )
60+
61+ expected = X / np .linalg .norm (X )
62+
6263 np .testing .assert_almost_equal (normalized , expected )
6364
6465
6566def test_save_pretrained (
66- tmp_path : Path , mock_vectors : np .ndarray , mock_tokenizer : PreTrainedTokenizerFast , mock_config : dict [str , str ]
67+ tmp_path : Path , mock_vectors : np .ndarray , mock_tokenizer : Tokenizer , mock_config : dict [str , str ]
6768) -> None :
6869 """Test saving a pretrained model."""
6970 model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config )
@@ -80,7 +81,7 @@ def test_save_pretrained(
8081
8182
8283def test_load_pretrained (
83- tmp_path : Path , mock_vectors : np .ndarray , mock_tokenizer : PreTrainedTokenizerFast , mock_config : dict [str , str ]
84+ tmp_path : Path , mock_vectors : np .ndarray , mock_tokenizer : Tokenizer , mock_config : dict [str , str ]
8485) -> None :
8586 """Test loading a pretrained model after saving it."""
8687 # Save the model to a temporary path
@@ -92,6 +93,33 @@ def test_load_pretrained(
9293 loaded_model = StaticModel .from_pretrained (save_path )
9394
9495 # Assert that the loaded model has the same properties as the original one
95- np .testing .assert_array_equal (loaded_model .vectors , mock_vectors )
96+ np .testing .assert_array_equal (loaded_model .embedding . weight . numpy () , mock_vectors )
9697 assert loaded_model .tokenizer .get_vocab () == mock_tokenizer .get_vocab ()
9798 assert loaded_model .config == mock_config
99+
100+
101+ def test_initialize_normalize (mock_vectors : np .ndarray , mock_tokenizer : Tokenizer ) -> None :
102+ """Tests whether the normalization initialization is correct."""
103+ model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = None )
104+ assert not model .normalize
105+
106+ model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = False )
107+ assert not model .normalize
108+
109+ model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = True )
110+ assert model .normalize
111+
112+ model = StaticModel (mock_vectors , mock_tokenizer , {"normalize" : False }, normalize = True )
113+ assert model .normalize
114+
115+ model = StaticModel (mock_vectors , mock_tokenizer , {"normalize" : True }, normalize = False )
116+ assert not model .normalize
117+
118+
119+ def test_set_normalize (mock_vectors : np .ndarray , mock_tokenizer : Tokenizer ) -> None :
120+ """Tests whether the normalize is set correctly."""
121+ model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = True )
122+ model .normalize = False
123+ assert model .config == {"normalize" : False }
124+ model .normalize = True
125+ assert model .config == {"normalize" : True }
0 commit comments