@@ -96,3 +96,30 @@ def test_load_pretrained(
9696 np .testing .assert_array_equal (loaded_model .embedding .weight .numpy (), mock_vectors )
9797 assert loaded_model .tokenizer .get_vocab () == mock_tokenizer .get_vocab ()
9898 assert loaded_model .config == mock_config
99+
100+
101+ def test_initialize_normalize (mock_vectors : np .ndarray , mock_tokenizer : Tokenizer ) -> None :
102+ """Tests whether the normalization initialization is correct."""
103+ model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = None )
104+ assert not model .normalize
105+
106+ model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = False )
107+ assert not model .normalize
108+
109+ model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = True )
110+ assert model .normalize
111+
112+ model = StaticModel (mock_vectors , mock_tokenizer , {"normalize" : False }, normalize = True )
113+ assert model .normalize
114+
115+ model = StaticModel (mock_vectors , mock_tokenizer , {"normalize" : True }, normalize = False )
116+ assert not model .normalize
117+
118+
119+ def test_set_normalize (mock_vectors : np .ndarray , mock_tokenizer : Tokenizer ) -> None :
120+ """Tests whether the normalize is set correctly."""
121+ model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = True )
122+ model .normalize = False
123+ assert model .config == {"normalize" : False }
124+ model .normalize = True
125+ assert model .config == {"normalize" : True }
0 commit comments